diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index b5515aa970..c3799b1940 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -74,28 +74,21 @@ body: If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: ```python - from vllm_omni import OmniLLM, create_ar_stage_config, create_dit_stage_config - - # Create stage configurations - ar_config = create_ar_stage_config( - stage_id=0, - model_path="Qwen/Qwen3-0.6B", - input_modalities=["text"], - output_modalities=["text"] - ) + from vllm_omni.entrypoints.omni import Omni + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + from vllm import SamplingParams - dit_config = create_dit_stage_config( - stage_id=1, - model_path="stabilityai/stable-diffusion-2-1", - input_modalities=["text"], - output_modalities=["image"] + omni = Omni( + model="Qwen/Qwen-Image", + stage_configs_path="/path/to/stage_configs.yaml", ) - # Initialize OmniLLM - omni = OmniLLM([ar_config, dit_config]) - - # Generate - outputs = omni.generate(prompt="A scenic watercolor painting of a lighthouse at sunset") + prompts = [{"prompt": "A scenic watercolor painting of a lighthouse at sunset"}] + sampling_params_list = [ + SamplingParams(max_tokens=1), + OmniDiffusionSamplingParams(num_outputs_per_prompt=1), + ] + outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) ``` If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 
diff --git a/benchmarks/qwen3-omni/README.md b/benchmarks/qwen3-omni/README.md index f82e20a419..de27c05c2c 100644 --- a/benchmarks/qwen3-omni/README.md +++ b/benchmarks/qwen3-omni/README.md @@ -56,9 +56,9 @@ What it does: - Runs `examples/offline_inference/qwen3_omni/end2end.py` with `--log-stats`. - Uses `benchmarks/build_dataset/top100.txt` and writes to: - Logs: `benchmarks/qwen3-omni/vllm_omni/logs/` - - `omni_llm_pipeline_text.orchestrator.stats.jsonl` — per-stage latency stats. - - `omni_llm_pipeline_text.overall.stats.jsonl` — end-to-end latency/TPS. - - `omni_llm_pipeline_text.stage{0,1,2}.log` — per-stage detailed logs/errors. + - `omni_pipeline_text.orchestrator.stats.jsonl` — per-stage latency stats. + - `omni_pipeline_text.overall.stats.jsonl` — end-to-end latency/TPS. + - `omni_pipeline_text.stage{0,1,2}.log` — per-stage detailed logs/errors. - Outputs: `benchmarks/qwen3-omni/vllm_omni/outputs/` — ~100 text and `.wav` files. Key checks: diff --git a/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh b/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh index 935753605a..e4c83e9751 100644 --- a/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh +++ b/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh @@ -26,12 +26,12 @@ else --log-stats \ --log-dir $log_dir echo "Logs and outputs are saved in ${log_dir} and ${outputs_dir} respectively:" - echo " - omni_llm_pipeline_text run dir/base name" - echo " - omni_llm_pipeline_text.orchestrator.stats.jsonl orchestrator-stage latency stats" - echo " - omni_llm_pipeline_text.overall.stats.jsonl overall latency/TPS stats" - echo " - omni_llm_pipeline_text.stage0.log per-stage detailed logs" - echo " - omni_llm_pipeline_text.stage1.log" - echo " - omni_llm_pipeline_text.stage2.log" + echo " - omni_pipeline_text run dir/base name" + echo " - omni_pipeline_text.orchestrator.stats.jsonl orchestrator-stage latency stats" + echo " - omni_pipeline_text.overall.stats.jsonl overall latency/TPS stats" + echo " 
- omni_pipeline_text.stage0.log per-stage detailed logs" + echo " - omni_pipeline_text.stage1.log" + echo " - omni_pipeline_text.stage2.log" echo "Key checks: overall.stats.jsonl for end-to-end latency/TPS; orchestrator.stats.jsonl for stable per-stage latency; stage*.log for errors or long tails." echo " - outputs/ Generated txt and wav files, there should be 100 text and wav files generated respectively" fi diff --git a/docs/api/README.md b/docs/api/README.md index 1b90b022c0..549d8bb1db 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -6,19 +6,13 @@ Main entry points for vLLM-Omni inference and serving. - [vllm_omni.entrypoints.async_omni.AsyncOmni][] - [vllm_omni.entrypoints.async_omni_diffusion.AsyncOmniDiffusion][] -- [vllm_omni.entrypoints.async_omni_llm.AsyncOmniLLM][] - [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][] - [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][] - [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][] - [vllm_omni.entrypoints.cli.serve.OmniServeCommand][] - [vllm_omni.entrypoints.client_request_state.ClientRequestState][] - [vllm_omni.entrypoints.omni.Omni][] -- [vllm_omni.entrypoints.omni.OmniBase][] -- [vllm_omni.entrypoints.omni_diffusion.OmniDiffusion][] -- [vllm_omni.entrypoints.omni_llm.OmniLLM][] -- [vllm_omni.entrypoints.omni_stage.OmniStage][] -- [vllm_omni.entrypoints.stage_utils.OmniStageTaskType][] -- [vllm_omni.entrypoints.zmq_utils.ZmqQueue][] +- [vllm_omni.entrypoints.omni_base.OmniBase][] ## Inputs @@ -48,9 +42,7 @@ Engine classes for offline and online inference. 
- [vllm_omni.engine.OmniEngineCoreOutputs][] - [vllm_omni.engine.OmniEngineCoreRequest][] - [vllm_omni.engine.PromptEmbedsPayload][] -- [vllm_omni.engine.arg_utils.AsyncOmniEngineArgs][] - [vllm_omni.engine.arg_utils.OmniEngineArgs][] -- [vllm_omni.engine.input_processor.OmniInputProcessor][] - [vllm_omni.engine.output_processor.MultimodalOutputProcessor][] - [vllm_omni.engine.output_processor.OmniRequestState][] diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md index ed69cdfe17..d8a26d516f 100644 --- a/docs/configuration/stage_configs.md +++ b/docs/configuration/stage_configs.md @@ -18,7 +18,7 @@ If users want to modify some part of it. The custom stage_configs file can be in For offline (Assume necessary dependencies have ben imported): ```python model_name = "Qwen/Qwen2.5-Omni-7B" -omni_llm = Omni(model=model_name, stage_configs_path="/path/to/custom_stage_configs.yaml") +omni = Omni(model=model_name, stage_configs_path="/path/to/custom_stage_configs.yaml") ``` For online serving: @@ -30,7 +30,7 @@ vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to Below is a specific example of stage_configs.yaml in Qwen2.5-omni. ```python -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni with AsyncOmniEngine + Orchestrator runtime. stage_args: - stage_id: 0 # mark the unique id for each stage runtime: # The disaggregated configuration diff --git a/docs/configuration/stage_configs/qwen2_5_omni.yaml b/docs/configuration/stage_configs/qwen2_5_omni.yaml index e20e79c3e9..416ad137d9 100644 --- a/docs/configuration/stage_configs/qwen2_5_omni.yaml +++ b/docs/configuration/stage_configs/qwen2_5_omni.yaml @@ -1,4 +1,4 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni with AsyncOmniEngine + Orchestrator runtime. 
stage_args: - stage_id: 0 runtime: diff --git a/docs/contributing/README.md b/docs/contributing/README.md index 77d96066ae..8a5bcfff0a 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -107,7 +107,7 @@ Only specific types of PRs will be reviewed. The PR title is prefixed appropriat - `[CI/Build]` for build or continuous integration improvements. - `[Doc]` for documentation fixes and improvements. - `[Model]` for adding a new model or improving an existing model. Model name should appear in the title. -- `[Frontend]` For changes on the vLLM-Omni frontend (e.g., OpenAI API server, `OmniLLM` class, etc.) +- `[Frontend]` For changes on the vLLM-Omni frontend (e.g., OpenAI API server, `Omni`/`AsyncOmni`, etc.) - `[Kernel]` for changes affecting CUDA kernels or other compute kernels. - `[Core]` for changes in the core vLLM-Omni logic (e.g., `OmniProcessor`, `OmniARScheduler`, etc.) - `[Hardware][Vendor]` for hardware-specific changes. Vendor name should appear in the prefix, such as [Ascend] for Ascend NPUs. 
diff --git a/docs/contributing/ci/CI_5levels.md b/docs/contributing/ci/CI_5levels.md index 54af850bbc..03b907f323 100644 --- a/docs/contributing/ci/CI_5levels.md +++ b/docs/contributing/ci/CI_5levels.md @@ -168,11 +168,6 @@ vllm_omni/ tests/ │ └── arg_utils.py │ └── test_arg_utils.py ⬜ │ ├── entrypoints/ → ├── entrypoints/ -│ ├── omni.py │ ├── test_omni.py ⬜ (E2E covered by e2e/offline, e2e/online) -│ ├── omni_llm.py │ ├── test_omni_llm.py ✅ -│ ├── omni_stage.py │ ├── test_omni_stage.py ⬜ (partial in test_omni_stage_diffusion_config.py) -│ ├── omni_diffusion.py │ ├── test_omni_diffusion.py ✅ -│ ├── async_omni.py │ ├── test_async_omni.py ✅ actually in e2e/online_serving/test_async_omni.py │ ├── async_omni_diffusion.py │ ├── test_async_omni_diffusion_config.py ✅ │ ├── stage_utils.py │ ├── test_stage_utils.py ✅ │ ├── cli/ │ ├── cli/ (benchmarks/test_serve_cli.py covers CLI serve) diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md index a24f718647..5d642fdb95 100644 --- a/docs/contributing/ci/tests_style.md +++ b/docs/contributing/ci/tests_style.md @@ -24,8 +24,6 @@ End-to-end tests verify the complete functionality of a system or component. For - **`tests/e2e/online_serving/`**: Tests for online serving scenarios (e.g., API server tests) -**Example:** The test file for `vllm_omni/entrypoints/omni_llm.py` should be located at `tests/entrypoints/test_omni_llm.py`. - ## Test Directory Structure The ideal directory structure mirrors the source code organization. Legend: `✅` = test exists, `⬜` = suggested to add. 
@@ -75,11 +73,6 @@ vllm_omni/ tests/ │ └── arg_utils.py │ └── test_arg_utils.py ⬜ │ ├── entrypoints/ → ├── entrypoints/ -│ ├── omni.py │ ├── test_omni.py ⬜ (E2E covered by e2e/offline, e2e/online) -│ ├── omni_llm.py │ ├── test_omni_llm.py ✅ -│ ├── omni_stage.py │ ├── test_omni_stage.py ⬜ (partial in test_omni_stage_diffusion_config.py) -│ ├── omni_diffusion.py │ ├── test_omni_diffusion.py ✅ -│ ├── async_omni.py │ ├── test_async_omni.py ✅ actually in e2e/online_serving/test_async_omni.py │ ├── async_omni_diffusion.py │ ├── test_async_omni_diffusion_config.py ✅ │ ├── stage_utils.py │ ├── test_stage_utils.py ✅ │ ├── cli/ │ ├── cli/ (benchmarks/test_serve_cli.py covers CLI serve) @@ -170,7 +163,7 @@ vllm_omni/ tests/ ### Naming Conventions -- **Unit Tests**: Use `test_.py` format. Example: `omni_llm.py` → `test_omni_llm.py` +- **Unit Tests**: Use `test_.py` format. Example: `stage_utils.py` → `test_stage_utils.py` - **E2E Tests**: Place in `tests/e2e/offline_inference/` or `tests/e2e/online_serving/` with descriptive names. Example: `tests/e2e/offline_inference/test_qwen3_omni.py`, `tests/e2e/offline_inference/test_diffusion_model.py` diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md index 6216b4b785..a0619e3381 100644 --- a/docs/contributing/model/adding_omni_model.md +++ b/docs/contributing/model/adding_omni_model.md @@ -330,54 +330,46 @@ Stage transitions are the mechanism by which outputs from one stage are converte ### Where Stage Transitions Are Called -Stage transitions happen automatically in the orchestrator (`OmniLLM` class) during the generation loop. Here's the detailed flow: +Stage transitions happen automatically in the runtime orchestrator. Here's the detailed flow: -1. **Location**: `vllm_omni/entrypoints/omni_llm.py` in the `_run_generation()` method +1. **Location**: `vllm_omni/engine/orchestrator.py` in `_forward_to_next_stage()` 2. 
**Trigger**: When a stage completes processing and produces outputs 3. **Execution Flow**: ```python - # In omni_llm.py, _run_generation() method (around line 345-460) - - # Main orchestrator loop polls each stage for completed requests - for stage_id, stage in enumerate(self.stage_list): - result = stage.try_collect() # Get completed request - if result is None: - continue - - # Store outputs from this stage - engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") - stage.set_engine_outputs(engine_outputs) - - # Check if there's a next stage to forward to - next_stage_id = stage_id + 1 - if next_stage_id < len(self.stage_list): - next_stage: OmniStage = self.stage_list[next_stage_id] - - # THIS IS WHERE STAGE TRANSITION HAPPENS - next_inputs = next_stage.process_engine_inputs( - self.stage_list, - [request_id_to_prompt[req_id]] - ) - - # Submit to next stage - task = { - "type": OmniStageTaskType.GENERATE, - "request_id": req_id, - "engine_inputs": next_inputs[0], - "sampling_params": sampling_params_list[next_stage_id], - } - next_stage.submit(task) + # In orchestrator.py + next_stage_id = stage_id + 1 + next_client = self.stage_clients[next_stage_id] + params = req_state.sampling_params_list[next_stage_id] + + # Save current stage outputs so stage_input_processors can consume them. + self.stage_clients[stage_id].set_engine_outputs([output]) + + # THIS IS WHERE STAGE TRANSITION HAPPENS + next_inputs = next_client.process_engine_inputs( + stage_list=self.stage_clients, + prompt=req_state.prompt, + ) + + # Build and submit request(s) to the next stage. + for next_input in next_inputs: + request = build_engine_core_request_from_tokens( + request_id=req_id, + prompt=next_input, + params=params, + model_config=self.stage_vllm_configs[next_stage_id].model_config, + ) + await next_client.add_request_async(request) ``` ### How Stage Transitions Work The stage transition process follows these steps: -1. 
**Stage Completion**: When a stage finishes processing a request, it stores outputs via `stage.set_engine_outputs(engine_outputs)` +1. **Stage Completion**: When a stage finishes processing a request, the orchestrator stores outputs via `stage_client.set_engine_outputs(...)` 2. **Transition Detection**: The orchestrator checks if there's a next stage and calls `process_engine_inputs()` on it -3. **Input Processing**: The `process_engine_inputs()` method in `OmniStage` (`omni_stage.py`) handles the transition: +3. **Input Processing**: The stage input processor configured in stage YAML (under `vllm_omni/model_executor/stage_input_processors/`) handles the transition: ```python def process_engine_inputs( self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 3e54aed8f9..8dbe4c0f6d 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -23,30 +23,32 @@ export VLLM_PROFILER_MAX_ITERS=1 The profiler is default to function across all stages. But It is highly recommended to profile specific stages by passing the stages list, preventing from producing too large trace files: ```python # Profile all stages -omni_llm.start_profile() +omni.start_profile() # Only profile Stage 1 -omni_llm.start_profile(stages=[1]) +omni.start_profile(stages=[1]) ``` ```python # Stage 0 (Thinker) and Stage 2 (Audio Decoder) for qwen omni -omni_llm.start_profile(stages=[0, 2]) +omni.start_profile(stages=[0, 2]) ``` **Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`. ```python -from vllm_omni import omni_llm +from vllm_omni.entrypoints.omni import Omni + +omni = Omni(model="Qwen/Qwen3-Omni-30B-A3B-Instruct") profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) # 1. 
Start profiling if enabled if profiler_enabled: - omni_llm.start_profile(stages=[0]) + omni.start_profile(stages=[0]) # Initialize generator -omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator) +omni_generator = omni.generate(prompts, sampling_params_list, py_generator=args.py_generator) total_requests = len(prompts) processed_count = 0 @@ -57,21 +59,21 @@ for stage_outputs in omni_generator: # ... [Output processing logic for text/audio would go here] ... # Update count to track when to stop profiling - processed_count += len(stage_outputs.request_output) + processed_count += 1 # 2. Check if all requests are done to stop the profiler safely if profiler_enabled and processed_count >= total_requests: print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") # Stop the profiler while workers are still active - omni_llm.stop_profile() + omni.stop_profile() # Wait for traces to flush to disk print("[Info] Waiting 30s for workers to write trace files to disk...") time.sleep(30) print("[Info] Trace export wait time finished.") -omni_llm.close() +omni.close() ``` diff --git a/docs/design/architecture_overview.md b/docs/design/architecture_overview.md index 2e58320ade..1c38ba6718 100644 --- a/docs/design/architecture_overview.md +++ b/docs/design/architecture_overview.md @@ -67,12 +67,12 @@ According to analysis for current popular open-source models, most of them have | Component | Description | | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------- | | **OmniRouter** | provide an intelligent router for Omni-modality requests dispatch | -| **EntryPoints** | define the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provide the OmniStage abstraction for different AR/DiT stages | +| **EntryPoints** | define the APIs for offline/online serving (APIServer, 
Omni/AsyncOmni), while `AsyncOmniEngine` and `Orchestrator` coordinate multi-stage AR/DiT execution | | **AR** | adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management | | **Diffusion** | natively implemented and optimized using acceleration components | | **OmniConnector** | supports fully disaggregation based on E/P/D/G (Encoding/Processing/Decoding/Generation) disaggregation across stages | -Disaggregated stages are managed through configuration, such as in the Qwen3-Omni example, where stages like Thinker, Talker, and Code2wav are defined as separate OmniStage instances with specific resources and input/output type. +Disaggregated stages are managed through stage configuration. In Qwen3-Omni, Thinker/Talker/Code2wav are declared as separate configured stages, and runtime routing is handled by `Orchestrator` over `StageEngineCoreClient` / `StageDiffusionClient`. ## Main features @@ -127,10 +127,10 @@ Taking **Qwen3-Omni** as an example: The **Omni** class provides a Python interface for offline batched inference. Users initialize the Omni class with a Hugging Face model name and use the generate method, passing inputs that include both text prompts and multi-modal data: ``` -# Create an omni_lm with HF model name. +# Create an omni runtime with HF model name. from vllm_omni.entrypoints.omni import Omni -omni_lm = Omni(model="Qwen/Qwen3-Omni-30B-A3B-Instruct") +omni = Omni(model="Qwen/Qwen3-Omni-30B-A3B-Instruct") # Example prompts. om_inputs = {"prompt": prompt, @@ -140,7 +140,7 @@ om_inputs = {"prompt": prompt, }} # Generate texts and audio from the multi-modality inputs. 
-outputs = omni_lm.generate(om_inputs, sampling_params_list) +outputs = omni.generate(om_inputs, sampling_params_list) ``` ## Online Serving diff --git a/docs/design/module/ar_module.md b/docs/design/module/ar_module.md index c0f7cddf04..5e0aa5b071 100644 --- a/docs/design/module/ar_module.md +++ b/docs/design/module/ar_module.md @@ -85,9 +85,6 @@ classDiagram class InputProcessor { +process_inputs() EngineCoreRequest } - class OmniInputProcessor { - +process_inputs() OmniEngineCoreRequest - } class VLLMOutputProcessor { +process_outputs() OutputProcessorOutput @@ -96,7 +93,6 @@ classDiagram +process_outputs() OutputProcessorOutput +_route_and_normalize() } - InputProcessor <|-- OmniInputProcessor VLLMOutputProcessor <|-- MultimodalOutputProcessor ``` @@ -105,7 +101,7 @@ classDiagram - **Scheduler**: `OmniARScheduler` extends `vllm.v1.core.sched.scheduler.Scheduler` to enrich scheduled requests with omni-specific payloads - **Worker**: `GPUARWorker` extends `vllm.v1.worker.gpu_worker.Worker` to initialize AR-specific model runners - **ModelRunner**: `GPUARModelRunner` extends `OmniGPUModelRunner` → `vllm.v1.worker.gpu_model_runner.GPUModelRunner` to expose hidden states and handle multimodal outputs -- **InputProcessor**: `OmniInputProcessor` extends `vllm.v1.engine.input_processor.InputProcessor` to serialize prompt embeddings and additional information +- **InputProcessor**: Stage-0 uses upstream `vllm.v1.engine.input_processor.InputProcessor`; `AsyncOmniEngine` then restores omni-specific payloads (for example `additional_information` and `prompt_embeds`) when building `OmniEngineCoreRequest` - **OutputProcessor**: `MultimodalOutputProcessor` extends `vllm.v1.engine.output_processor.OutputProcessor` to route and accumulate multimodal outputs ## 3. 
Scheduler Design @@ -118,7 +114,7 @@ The following diagram illustrates the request flow through the AR module compone ```mermaid flowchart TD - A[OmniInputProcessor] -->|OmniEngineCoreRequest| B[OmniARScheduler] + A[InputProcessor stage-0 in AsyncOmniEngine] -->|EngineCoreRequest then upgraded to OmniEngineCoreRequest| B[OmniARScheduler] B -->|schedule: OmniNewRequestData| C[GPUARWorker] C -->|SchedulerOutput| D[GPUARModelRunner] D -->|execute_model: None| E[Model Forward Pass] @@ -297,15 +293,17 @@ The input/output processing pipeline handles serialization, routing, and accumul ```mermaid sequenceDiagram participant Client - participant OmniInputProcessor + participant AsyncOmniEngine + participant InputProcessor participant Scheduler participant ModelRunner participant MultimodalOutputProcessor participant Client - Client->>OmniInputProcessor: prompt + prompt_embeds + additional_info - OmniInputProcessor->>OmniInputProcessor: Serialize tensors to payloads - OmniInputProcessor->>Scheduler: OmniEngineCoreRequest (with payloads) + Client->>AsyncOmniEngine: prompt + prompt_embeds + additional_info + AsyncOmniEngine->>InputProcessor: process_inputs() + InputProcessor->>Scheduler: EngineCoreRequest + AsyncOmniEngine->>AsyncOmniEngine: _upgrade_to_omni_request() + serialize_additional_information() Scheduler->>ModelRunner: OmniNewRequestData (with payloads) ModelRunner->>ModelRunner: Decode payloads → CPU tensors ModelRunner->>ModelRunner: Overlay prompt_embeds on inputs_embeds @@ -318,43 +316,18 @@ sequenceDiagram MultimodalOutputProcessor->>Client: RequestOutput (with multimodal_output) ``` -### OmniInputProcessor - -`OmniInputProcessor` extends the base input processor to serialize prompt embeddings and additional information for inter-stage transfer. 
- -#### Payload Serialization +### Stage-0 Input Processing -Converts PyTorch tensors to serialized payloads: +Stage-0 now uses upstream `InputProcessor` directly, and `AsyncOmniEngine` upgrades the request to `OmniEngineCoreRequest` while restoring omni-specific payloads. ```python -def process_inputs(self, ...) -> OmniEngineCoreRequest: - # Serialize prompt_embeds - if "prompt_embeds" in decoder_inputs: - pe_cpu = decoder_inputs["prompt_embeds"].detach().to("cpu").contiguous() - prompt_embeds_payload = PromptEmbedsPayload( - data=pe_cpu.numpy().tobytes(), - shape=[seq_len, hidden_size], - dtype=dtype_str, - ) - - # Serialize additional_information - if "additional_information" in decoder_inputs: - entries = {} - for key, value in raw_info.items(): - if isinstance(value, torch.Tensor): - entry = AdditionalInformationEntry( - tensor_data=value.numpy().tobytes(), - tensor_shape=list(value.shape), - tensor_dtype=dtype_str, - ) - entries[key] = entry - additional_information_payload = AdditionalInformationPayload(entries=entries) - - return OmniEngineCoreRequest( - # ... standard fields ... - prompt_embeds=prompt_embeds_payload, - additional_information=additional_information_payload, - ) +request = self.input_processor.process_inputs( + request_id=request_id, + prompt=prompt, + params=params, + supported_tasks=self.supported_tasks, +) +request = _upgrade_to_omni_request(request, prompt) ``` ### MultimodalOutputProcessor @@ -400,13 +373,13 @@ The AR module of vLLM-Omni extends vLLM through strategic inheritance and minima ### Key Design Patterns 1. **Inheritance over composition**: Extends vLLM classes to preserve compatibility with existing scheduling, batching, and execution mechanisms -2. **Payload serialization**: Uses `PromptEmbedsPayload` and `AdditionalInformationPayload` for efficient inter-stage data transfer +2. 
**Payload serialization**: Uses serialized `additional_information` payloads together with prompt-embedding handoff for efficient inter-stage transfer 3. **Two-phase execution**: Maintains vLLM's execute/sample separation for AR models while supporting single-phase execution for generation models 4. **Multimodal routing**: Routes outputs by `output_type` and accumulates tensors incrementally to support streaming ### Differences from vLLM -- **Payload support**: Serialized prompt embeddings and additional information enable direct transfer between pipeline stages +- **Payload support**: Serialized additional information and prompt embeddings enable direct transfer between pipeline stages - **Multimodal handling**: Extended input/output processors support images, audio, and other modalities alongside text - **Hidden state exposure**: AR model runners expose per-request hidden states via `pooler_output` for downstream consumption - **Generation scheduler**: Fast-path scheduling for basic heterogeneous architectures that complete in one step diff --git a/docs/design/module/async_omni_architecture.md b/docs/design/module/async_omni_architecture.md new file mode 100644 index 0000000000..59275c556f --- /dev/null +++ b/docs/design/module/async_omni_architecture.md @@ -0,0 +1,203 @@ +# AsyncOmni Architecture (Qwen3-Omni Example) + +## 1. 
System Architecture + +```text +• ┌─────────────────────────────────────────────────────────────────────────────────┐ + │ API Layer │ + │ ┌─────────────────────────────────────┐ ┌──────────────────────────────────┐ │ + │ │ AsyncOmni (EngineClient) │ │ Omni │ │ + │ │ • generate() / abort() / shutdown() │ │ • generate() │ │ + │ │ • _final_output_handler() │ │ | │ + │ └─────────────────────────────────────┘ └──────────────────────────────────┘ │ + ├─────────────────────────────────────────────────────────────────────────────────┤ + │ Engine Layer (Proxy) │ + │ ┌───────────────────────────────────────────────────────────────────────────┐ │ + │ │ AsyncOmniEngine │ │ + │ │ • _bootstrap_orchestrator() & _initialize_stages() │ │ + │ │ • add_request() / add_request_async() -> input_processor.process_inputs() │ │ + │ │ • try_get_output() / try_get_output_async() │ │ + │ └───────────────────┬─────────────────────────────────▲─────────────────────┘ │ + │ request_queue (janus.Queue) output_queue (janus.Queue) │ + ├──────────────────────┼─────────────────────────────────┼────────────────────────┤ + │ ▼ Orchestration Layer │ │ + │ ┌───────────────────────────────────────────────────────────────────────────┐ │ + │ │ Orchestrator [background thread] │ │ + │ │ • _request_handler() │ │ + │ │ - stage_client.add_request_async() & _prewarm_async_chunk_stages() │ │ + │ │ • _orchestration_output_handler() │ │ + │ │ - _process_stage_outputs() -> output_processors[i].process_outputs() │ │ + │ │ - _route_output() & _forward_to_next_stage() │ │ + │ └──────────┬─────────────────────────┬────────────────────────┬─────────────┘ │ + ├─────────────┼─────────────────────────┼────────────────────────┼────────────────┤ + │ │ Communication Layer │ │ + │ ┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐ │ + │ │ StageEngineCoreClient │ │ StageEngineCoreClient │ │ StageDiffusionClient │ │ + │ │ • ZMQ ROUTER / PULL │ │ • ZMQ ROUTER / PULL │ │ • ZMQ ROUTER / PULL │ │ + │ │ • 
Msgpack codec │ │ • Msgpack codec │ │ • Msgpack codec │ │ + │ └──────────┬────────────┘ └──────────┬────────────┘ └──────────┬────────────┘ │ + │ ▼ ZMQ IPC ▼ ZMQ IPC ▼ ZMQ IPC │ + ├─────────────────────────────────────────────────────────────────────────────────┤ + │ Execution Layer │ + │ ┌───────────────────────┐ ┌───────────────────────┐ ┌───────────────────────┐ │ + │ │ StageCoreProc │ │ StageCoreProc │ │ DiffusionEngine │ │ + │ │ [background process] │ │ [background process] │ │ [background process] │ │ + │ └───────────────────────┘ └───────────────────────┘ └───────────────────────┘ │ + └─────────────────────────────────────────────────────────────────────────────────┘ +``` + +## 2. Execution Flow (Arrow Steps, one generate request) + +```text +[1] App + -> AsyncOmni.generate(prompt, request_id) + +[2] AsyncOmni + -> _final_output_handler() (started on first request) + -> AsyncOmniEngine.add_request(stage_id=0, ...) + +[3] AsyncOmniEngine.add_request + -> (if stage-0 is llm and input is not EngineCoreRequest) + InputProcessor.process_inputs() + OutputProcessor[0].add_request() + -> request_queue.put(add_request_msg) + +[4] Orchestrator._request_handler + -> _handle_add_request(msg) + -> stage_clients[0].add_request_async(...) + +[5] Orchestrator._orchestration_loop (loop) + -> poll stage output + - llm stage: await get_output_async() + - diffusion stage: get_diffusion_output_async() + -> (llm stage) output_processors[i].process_outputs(...) + -> _route_output(...) + -> if finished and not final_stage and non-async-chunk: + _forward_to_next_stage(...) + -> next_stage.add_request_async(...) + -> output_queue.put(output) + +[6] AsyncOmni._final_output_loop (background coroutine) + -> AsyncOmniEngine.try_get_output_async() + -> route by request_id to ClientRequestState.queue + +[7] AsyncOmni._process_orchestrator_results + -> read from ClientRequestState.queue + -> _process_single_result(...) 
+ -> yield OmniRequestOutput + +[8] Exit condition + -> receive result["finished"] == True + -> generate() ends +``` + +## 3. Runtime Sequence (one generate request) + +```mermaid +sequenceDiagram + participant APP as App + participant AO as AsyncOmni + participant ENG as AsyncOmniEngine + participant ORCH as Orchestrator + participant S0 as Stage-0 Client + participant SN as Next Stage Client + + APP->>AO: generate + AO->>AO: start output_handler once + AO->>ENG: add_request(stage_id=0, ...) + ENG->>ENG: input_processor.process_inputs() + ENG->>ORCH: request_queue.put(add_request) + + ORCH->>ORCH: _handle_add_request + ORCH->>S0: add_request_async + + loop poll route forward + ORCH->>S0: get_output_async / get_diffusion_output_async + ORCH->>ORCH: _route_output + alt need forward to next stage + ORCH->>SN: add_request_async + end + ORCH-->>ENG: output_queue.put + end + + AO->>ENG: try_get_output_async + ENG-->>AO: message + AO-->>APP: yield OmniRequestOutput +``` + +## 4. Comparison + +Previous topology (reference): + +```text +┌────────────────────────────────────────────────────────────────────────────┐ +│ Main Process │ +│ ┌──────────────────────┐ ┌────────────────────────────────────────────┐ │ +│ │ generate() │ │ final_output_handler() │ │ +│ └──────────────────────┘ └────────────────────────────────────────────┘ │ +└──────────┬─────────────────────────┬─────────────────────────┬─────────────┘ + mp.Queue (in_q/out_q) mp.Queue (in_q/out_q) mp.Queue (in_q/out_q) + ▼▲ ▼▲ ▼▲ +┌───────────────────────┐ ┌───────────────────────┐ ┌──────────────────────┐ +│ Worker Proc-0 │ │ Worker Proc-1 │ │ Worker Proc-2 │ +│ (Thinker LLM) │ │ (Talker LLM) │ │ (Vocoder) │ +│ ┌────────────────┐ │ │ ┌────────────────┐ │ │ ┌────────────────┐ │ +│ │_stage_worker │ │ │ │_stage_worker │ │ │ │_stage_worker │ │ +│ │_async() │ │ │ │_async() │ │ │ │_async() │ │ +│ └────────────────┘ │ │ └────────────────┘ │ │ └────────────────┘ │ +│ ┌────────────────┐ │ │ ┌────────────────┐ │ │ 
┌────────────────┐ │ +│ │output_handler()│ │ │ │output_handler()│ │ │ │output_handler()│ │ +│ └────────────────┘ │ │ └────────────────┘ │ │ └────────────────┘ │ +└──────────┬────────────┘ └──────────┬────────────┘ └──────────┬───────────┘ + ZMQ ▼ ▲ ZMQ ZMQ ▼ ▲ ZMQ ZMQ ▼ ▲ ZMQ +┌──────────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐ +│ EngineCore Proc-0 │ │ EngineCore Proc-1 │ │ EngineCore Proc-2 │ +│ (Thinker) │ │ (Talker) │ │ (Vocoder) │ +└──────────────────────┘ └──────────────────────┘ └──────────────────────┘ +``` + +Current topology: + +```text +┌────────────────────────────────────────────────────────────────────────────┐ +│ Main Process │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Main Thread │ │ +│ │ ┌──────────────────────┐ ┌─────────────────────────────────────┐ │ │ +│ │ │ generate() │ │ final_output_handler() │ │ │ +│ │ └──────────────────────┘ └─────────────────────────────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ janus.Queue (request_queue) ▼ ▲ janus.Queue (output_queue) │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Orchestrator Thread │ │ +│ │ ┌──────────────────────┐ ┌──────────────────────────────────────┐ │ │ +│ │ │ _request_handler() │ │ _orchestration_output_handler() │ │ │ +│ │ └──────────────────────┘ └──────────────────────────────────────┘ │ │ +│ │ ┌────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ _orchestration_loop(): poll/process/route outputs for all stages│ │ │ +│ │ └────────────────────────────────────────────────────────────────┘ │ │ +│ └───────┬─────────────────────────┬─────────────────────────┬──────────┘ │ +└──────────┬─────────────────────────┬─────────────────────────┬─────────────┘ + ZMQ ▼ ▲ ZMQ ZMQ ▼ ▲ ZMQ ZMQ ▼ ▲ ZMQ + ┌──────────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐ + │ EngineCore Proc-0 │ │ EngineCore Proc-1 │ │ EngineCore 
# Enter the offline-inference examples folder.
a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md index 9b9cb6fc77..c3d5eb7eb9 100644 --- a/docs/user_guide/diffusion/quantization/gguf.md +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -22,10 +22,16 @@ CLI (examples/offline_inference/text_to_image/text_to_image.py) | v -Omni (vllm_omni/entrypoints/omni.py) +Omni (vllm_omni/entrypoints/__init__.py) | v -OmniStage (diffusion) +AsyncOmniEngine + | + v +Orchestrator + | + v +StageDiffusionClient | v DiffusionWorker @@ -64,7 +70,13 @@ _generate_with_async_omni AsyncOmni | v -DiffusionEngine +AsyncOmniEngine + | + v +Orchestrator + | + v +StageDiffusionClient | v OmniRequestOutput diff --git a/docs/user_guide/examples/offline_inference/text_to_image.md b/docs/user_guide/examples/offline_inference/text_to_image.md index 8b6dadc8ef..653973fe42 100644 --- a/docs/user_guide/examples/offline_inference/text_to_image.md +++ b/docs/user_guide/examples/offline_inference/text_to_image.md @@ -19,7 +19,7 @@ if __name__ == "__main__": omni = Omni(model="Qwen/Qwen-Image") prompt = "a cup of coffee on the table" outputs = omni.generate(prompt) - images = outputs[0].request_output[0].images + images = outputs[0].request_output.images images[0].save("coffee.png") ``` @@ -37,7 +37,7 @@ if __name__ == "__main__": ] outputs = omni.generate(prompts) for i, output in enumerate(outputs): - image = output.request_output[0].images[0].save(f"{i}.jpg") + image = output.request_output.images[0].save(f"{i}.jpg") ``` !!! 
info @@ -72,7 +72,7 @@ if __name__ == "__main__": } ]) for i, output in enumerate(outputs): - image = output.request_output[0].images[0].save(f"{i}.jpg") + image = output.request_output.images[0].save(f"{i}.jpg") ``` ## Local CLI Usage diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index cffe44fc36..dca0907a1e 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -34,7 +34,7 @@ def parse_args(): help="Path to input image for img2img.", ) - # OmniLLM init args + # Omni runtime init args parser.add_argument("--log-stats", action="store_true", default=False) parser.add_argument("--init-sleep-seconds", type=int, default=20) parser.add_argument("--batch-timeout", type=int, default=5) @@ -163,25 +163,17 @@ def main(): omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) - for i, req_output in enumerate(omni_outputs): + img_idx = 0 + for req_output in omni_outputs: images = getattr(req_output, "images", None) - if not images and hasattr(req_output, "output"): - if isinstance(req_output.output, list): - images = req_output.output - else: - images = [req_output.output] - - if images: - for j, img in enumerate(images): - img.save(f"output_{i}_{j}.png") - - if hasattr(req_output, "request_output") and req_output.request_output: - for stage_out in req_output.request_output: - if hasattr(stage_out, "images") and stage_out.images: - for k, img in enumerate(stage_out.images): - save_path = f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png" - img.save(save_path) - print(f"[Info] Saved stage output image to {save_path}") + if not images: + continue + + for j, img in enumerate(images): + save_path = f"output_{img_idx}_{j}.png" + img.save(save_path) + print(f"[Info] Saved image to {save_path}") + img_idx += 1 print(omni_outputs) diff --git a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py 
b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py index c962aac805..68ab72b387 100644 --- a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py +++ b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py @@ -167,27 +167,26 @@ def run_e2e(): print(f"Received {len(outputs)} outputs.") for i, output in enumerate(outputs): try: - ro_list = output.request_output or [] - if not ro_list: + ro = output.request_output + if ro is None: print("No request_output found.") continue - for ro in ro_list: - # Multimodal output may be attached to RequestOutput or CompletionOutput. - mm = getattr(ro, "multimodal_output", None) - if not mm and ro.outputs: - mm = getattr(ro.outputs[0], "multimodal_output", None) - - if mm: - print(f"Multimodal output keys: {mm.keys()}") - if "audio" in mm: - audio_out = mm["audio"] - print(f"Generated Audio Shape: {audio_out.shape}") - out_path = f"output_{i}.wav" - sf.write(out_path, audio_out.cpu().numpy().squeeze(), 22050) - print(f"Saved audio to {out_path}") - else: - print("No multimodal output found.") + # Multimodal output may be attached to RequestOutput or CompletionOutput. 
+ mm = getattr(ro, "multimodal_output", None) + if not mm and ro.outputs: + mm = getattr(ro.outputs[0], "multimodal_output", None) + + if mm: + print(f"Multimodal output keys: {mm.keys()}") + if "audio" in mm: + audio_out = mm["audio"] + print(f"Generated Audio Shape: {audio_out.shape}") + out_path = f"output_{i}.wav" + sf.write(out_path, audio_out.cpu().numpy().squeeze(), 22050) + print(f"Saved audio to {out_path}") + else: + print("No multimodal output found.") except Exception as e: print(f"Error inspecting output: {e}") omni.close() diff --git a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py index a9b3c32f81..8ab5e0d9a6 100644 --- a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py +++ b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py @@ -209,7 +209,7 @@ async def main(): if not outputs: raise ValueError("No output produced from omni.generate()") - first_out = outputs[0].request_output[0] + first_out = outputs[0].request_output req_out: OmniRequestOutput = first_out # Verify trajectory data (from custom pipeline) diff --git a/examples/offline_inference/fish_speech/end2end.py b/examples/offline_inference/fish_speech/end2end.py index 9e95f7e5c7..78bbde857f 100644 --- a/examples/offline_inference/fish_speech/end2end.py +++ b/examples/offline_inference/fish_speech/end2end.py @@ -129,8 +129,14 @@ def main(args): t_start = time.perf_counter() for stage_outputs in omni.generate(inputs): - for output in stage_outputs.request_output: - _save_wav(output_dir, output.request_id, output.outputs[0].multimodal_output) + request_output = stage_outputs.request_output + if request_output is None or not request_output.outputs: + continue + _save_wav( + output_dir, + request_output.request_id, + request_output.outputs[0].multimodal_output, + ) t_end = time.perf_counter() logger.info("Total inference time: %.1f ms", (t_end - t_start) * 
1000) diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py index 77de043a4c..a3c1b7dc2b 100644 --- a/examples/offline_inference/glm_image/end2end.py +++ b/examples/offline_inference/glm_image/end2end.py @@ -310,36 +310,35 @@ def main(args: argparse.Namespace) -> None: output_count = 0 for stage_outputs in omni.generate(prompts, sampling_params_list, py_generator=True): + output = stage_outputs.request_output if stage_outputs.final_output_type == "image": - for output in stage_outputs.request_output: - request_id = output.request_id + request_id = output.request_id - # Get generated images - images = output.images if hasattr(output, "images") else [] - if not images and hasattr(output, "multimodal_output"): - images = output.multimodal_output.get("images", []) + # Get generated images + images = output.images if hasattr(output, "images") else [] + if not images and hasattr(output, "multimodal_output"): + images = output.multimodal_output.get("images", []) - # Save each generated image - for idx, img in enumerate(images): - if args.num_prompts == 1 and len(images) == 1: - output_path = args.output - else: - base, ext = os.path.splitext(args.output) - output_path = f"{base}_{request_id}_{idx}{ext}" + # Save each generated image + for idx, img in enumerate(images): + if args.num_prompts == 1 and len(images) == 1: + output_path = args.output + else: + base, ext = os.path.splitext(args.output) + output_path = f"{base}_{request_id}_{idx}{ext}" - if isinstance(img, Image.Image): - save_image(img, output_path) - else: - print(f"Warning: Unexpected image type for request {request_id}: {type(img)}") + if isinstance(img, Image.Image): + save_image(img, output_path) + else: + print(f"Warning: Unexpected image type for request {request_id}: {type(img)}") - output_count += 1 + output_count += 1 elif stage_outputs.final_output_type == "text": # AR stage output (intermediate, for debugging) if args.verbose: - for output in 
stage_outputs.request_output: - print(f"AR output for request {output.request_id}:") - print(f" Token count: {len(output.outputs[0].token_ids)}") + print(f"AR output for request {output.request_id}:") + print(f" Token count: {len(output.outputs[0].token_ids)}") gen_time = time.time() - gen_start_time print(f"\nGeneration completed in {gen_time:.2f}s") diff --git a/examples/offline_inference/helios/end2end.py b/examples/offline_inference/helios/end2end.py index 450c65cc47..36f5bff161 100644 --- a/examples/offline_inference/helios/end2end.py +++ b/examples/offline_inference/helios/end2end.py @@ -293,12 +293,11 @@ def main(): ) if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: - if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: - inner_output = first_item.request_output[0] - if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): - frames = inner_output.images[0] if inner_output.images else None - if frames is None: - raise ValueError("No video frames found in output.") + inner_output = first_item.request_output + if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): + frames = inner_output.images[0] if inner_output.images else None + if frames is None: + raise ValueError("No video frames found in output.") elif hasattr(first_item, "images") and first_item.images: frames = first_item.images else: diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py index dbae043155..dbe6cdea58 100644 --- a/examples/offline_inference/hunyuan_image3/image_to_text.py +++ b/examples/offline_inference/hunyuan_image3/image_to_text.py @@ -9,7 +9,7 @@ from vllm_omni.entrypoints.omni import Omni """ -The tencent/HunyuanImage-3.0-Instruct base model is built on the Hunyuan v1 architecture, specifically the tencent/Hunyuan-A13B-Instruct model. 
It utilizes two tokenizer delimiter templates: +The tencent/HunyuanImage-3.0-Instruct base model uses the tencent/Hunyuan-A13B-Instruct backbone. It utilizes two tokenizer delimiter templates: 1) Pretrained template (default for gen_text mode), which concatenates system, image tokens, and user question WITHOUT role delimiters: @@ -70,8 +70,8 @@ def main(args: argparse.Namespace) -> None: prompts = [prompt_dict] omni_outputs = omni.generate(prompts=prompts) - prompt_text = omni_outputs[0].request_output[0].prompt - generated_text = omni_outputs[0].request_output[0].outputs[0].text + prompt_text = omni_outputs[0].request_output.prompt + generated_text = omni_outputs[0].request_output.outputs[0].text print(f"Prompt: {prompt_text}") print(f"Text: {generated_text}") diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 83ac5af05b..c15478b718 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -463,12 +463,12 @@ def main(): raise ValueError("No output generated from omni.generate()") # Extract images from OmniRequestOutput - # omni.generate() returns list[OmniRequestOutput], extract images from request_output[0].images + # omni.generate() returns list[OmniRequestOutput], extract images from request_output.images first_output = outputs[0] if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 7b0e1bb2ef..6d4dc06acf 100644 --- 
a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -338,8 +338,6 @@ def main(): audio = frames.multimodal_output["audio"] if frames.is_pipeline_output and frames.request_output is not None: inner_output = frames.request_output - if isinstance(inner_output, list): - inner_output = inner_output[0] if inner_output else None if isinstance(inner_output, OmniRequestOutput): if inner_output.multimodal_output and "audio" in inner_output.multimodal_output: audio = inner_output.multimodal_output["audio"] diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py index e29f907f20..830557aa07 100644 --- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py +++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py @@ -99,14 +99,12 @@ def main() -> None: lines: list[str] = [] for stage_outputs in outputs: - req_outputs = getattr(stage_outputs, "request_output", stage_outputs) - req_outputs = req_outputs if isinstance(req_outputs, list) else [req_outputs] - for ro in req_outputs: - text = ro.outputs[0].text if getattr(ro, "outputs", None) else str(ro) - lines.append(f"request_id: {getattr(ro, 'request_id', 'unknown')}\n") - lines.append("answer:\n") - lines.append(text.strip() + "\n") - lines.append("\n") + ro = getattr(stage_outputs, "request_output", stage_outputs) + text = ro.outputs[0].text if getattr(ro, "outputs", None) else str(ro) + lines.append(f"request_id: {getattr(ro, 'request_id', 'unknown')}\n") + lines.append("answer:\n") + lines.append(text.strip() + "\n") + lines.append("\n") print("\n".join(lines)) diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py index 1f10fb2882..a4c41fee1f 100644 --- 
a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py +++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py @@ -148,19 +148,16 @@ def _collect_images(outputs: list) -> list[torch.Tensor]: """Extract all image tensors produced by the final (DiT) stage.""" images: list[torch.Tensor] = [] for out in outputs: - ro_list = getattr(out, "request_output", out) - if not isinstance(ro_list, list): - ro_list = [ro_list] - for ro_item in ro_list: - for completion in getattr(ro_item, "outputs", None) or []: - mm = getattr(completion, "multimodal_output", None) - if not isinstance(mm, dict) or "image" not in mm: - raise RuntimeError(f"Missing image in multimodal output: {mm}") - payload = mm["image"] - for tensor in payload if isinstance(payload, list) else [payload]: - if not isinstance(tensor, torch.Tensor): - raise TypeError(f"Expected image tensor, got {type(tensor)}") - images.append(tensor) + ro_item = getattr(out, "request_output", out) + for completion in getattr(ro_item, "outputs", None) or []: + mm = getattr(completion, "multimodal_output", None) + if not isinstance(mm, dict) or "image" not in mm: + raise RuntimeError(f"Missing image in multimodal output: {mm}") + payload = mm["image"] + for tensor in payload if isinstance(payload, list) else [payload]: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Expected image tensor, got {type(tensor)}") + images.append(tensor) return images diff --git a/examples/offline_inference/mimo_audio/end2end.py b/examples/offline_inference/mimo_audio/end2end.py index 67b9f14c06..ae044d2e8a 100644 --- a/examples/offline_inference/mimo_audio/end2end.py +++ b/examples/offline_inference/mimo_audio/end2end.py @@ -180,11 +180,11 @@ def main(args): # Get the query function and call it with appropriate parameters query_func = query_map[args.query_type] - omni_llm = Omni( + omni = Omni( model=model_name, stage_configs_path=args.stage_configs_path, log_stats=args.enable_stats, - 
log_file=("omni_llm_pipeline.log" if args.enable_stats else None), + log_file=("omni_pipeline.log" if args.enable_stats else None), init_sleep_seconds=args.init_sleep_seconds, batch_timeout=args.batch_timeout, init_timeout=args.init_timeout, @@ -290,7 +290,7 @@ def main(args): prompts = [copy.deepcopy(query_result) for _ in range(args.num_prompts)] print("prompts", prompts) - omni_outputs = omni_llm.generate(prompts, sampling_params_list) + omni_outputs = omni.generate(prompts, sampling_params_list) output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav if args.query_type is not None: @@ -298,45 +298,44 @@ def main(args): os.makedirs(output_dir, exist_ok=True) for stage_outputs in omni_outputs: + output = stage_outputs.request_output if stage_outputs.final_output_type == "text": - for output in stage_outputs.request_output: - request_id = output.request_id - text_output = output.outputs[0].text - # Save aligned text file per request - prompt_text = output.prompt - out_txt = os.path.join(output_dir, f"{request_id}.txt") - lines = [] - lines.append("Prompt:\n") - lines.append(str(prompt_text) + "\n") - lines.append("vllm_text_output:\n") - lines.append(str(text_output).strip() + "\n") - try: - with open(out_txt, "w", encoding="utf-8") as f: - print("lines", lines) - f.writelines(lines) - except Exception as e: - print(f"[Warn] Failed writing text file {out_txt}: {e}") - print(f"Request ID: {request_id}, Text saved to {out_txt}\n") + request_id = output.request_id + text_output = output.outputs[0].text + # Save aligned text file per request + prompt_text = output.prompt + out_txt = os.path.join(output_dir, f"{request_id}.txt") + lines = [] + lines.append("Prompt:\n") + lines.append(str(prompt_text) + "\n") + lines.append("vllm_text_output:\n") + lines.append(str(text_output).strip() + "\n") + try: + with open(out_txt, "w", encoding="utf-8") as f: + print("lines", lines) + f.writelines(lines) + except Exception as e: + print(f"[Warn] 
Failed writing text file {out_txt}: {e}") + print(f"Request ID: {request_id}, Text saved to {out_txt}\n") elif stage_outputs.final_output_type == "audio": - for output in stage_outputs.request_output: - request_id = output.request_id - audio_tensor = output.outputs[0].multimodal_output.get("audio") + request_id = output.request_id + audio_tensor = output.outputs[0].multimodal_output.get("audio") - if audio_tensor is None: - continue + if audio_tensor is None: + continue - output_wav = os.path.join(output_dir, f"{request_id}.wav") + output_wav = os.path.join(output_dir, f"{request_id}.wav") - # Convert to numpy array and ensure correct format - audio_numpy = audio_tensor.float().detach().cpu().numpy() + # Convert to numpy array and ensure correct format + audio_numpy = audio_tensor.float().detach().cpu().numpy() - # Ensure audio is 1D (flatten if needed) - if audio_numpy.ndim > 1: - audio_numpy = audio_numpy.flatten() + # Ensure audio is 1D (flatten if needed) + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() - # Save audio file with explicit WAV format - sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV") - print(f"Request ID: {request_id}, Audio saved to {output_wav}") + # Save audio file with explicit WAV format + sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV") + print(f"Request ID: {request_id}, Audio saved to {output_wav}") def parse_args(): diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index 7cd8e18737..7bba599830 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -320,7 +320,7 @@ def main(args): query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate) else: query_result = query_func() - omni_llm = Omni( + omni = Omni( model=model_name, log_stats=args.log_stats, stage_init_timeout=args.stage_init_timeout, @@ -378,9 +378,11 @@ def main(args): 
prompt["modalities"] = output_modalities profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) - if profiler_enabled: - omni_llm.start_profile(stages=[0]) - omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator) + if profiler_enabled and hasattr(omni, "start_profile"): + omni.start_profile(stages=[0]) + elif profiler_enabled: + print("[Warn] VLLM_TORCH_PROFILER_DIR is set, but current engine does not support profiler controls.") + omni_generator = omni.generate(prompts, sampling_params_list, py_generator=args.py_generator) # Determine output directory: prefer --output-dir; fallback to --output-wav output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav @@ -389,43 +391,42 @@ def main(args): total_requests = len(prompts) processed_count = 0 for stage_outputs in omni_generator: + output = stage_outputs.request_output if stage_outputs.final_output_type == "text": - for output in stage_outputs.request_output: - request_id = output.request_id - text_output = output.outputs[0].text - # Save aligned text file per request - prompt_text = output.prompt - out_txt = os.path.join(output_dir, f"{request_id}.txt") - lines = [] - lines.append("Prompt:\n") - lines.append(str(prompt_text) + "\n") - lines.append("vllm_text_output:\n") - lines.append(str(text_output).strip() + "\n") - try: - with open(out_txt, "w", encoding="utf-8") as f: - f.writelines(lines) - except Exception as e: - print(f"[Warn] Failed writing text file {out_txt}: {e}") - print(f"Request ID: {request_id}, Text saved to {out_txt}") + request_id = output.request_id + text_output = output.outputs[0].text + # Save aligned text file per request + prompt_text = output.prompt + out_txt = os.path.join(output_dir, f"{request_id}.txt") + lines = [] + lines.append("Prompt:\n") + lines.append(str(prompt_text) + "\n") + lines.append("vllm_text_output:\n") + lines.append(str(text_output).strip() + "\n") + try: + with open(out_txt, "w", 
encoding="utf-8") as f: + f.writelines(lines) + except Exception as e: + print(f"[Warn] Failed writing text file {out_txt}: {e}") + print(f"Request ID: {request_id}, Text saved to {out_txt}") elif stage_outputs.final_output_type == "audio": - for output in stage_outputs.request_output: - request_id = output.request_id - audio_tensor = output.outputs[0].multimodal_output["audio"] - output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - sf.write(output_wav, audio_tensor.detach().cpu().numpy(), samplerate=24000) - print(f"Request ID: {request_id}, Saved audio to {output_wav}") - - processed_count += len(stage_outputs.request_output) - if profiler_enabled and processed_count >= total_requests: + request_id = output.request_id + audio_tensor = output.outputs[0].multimodal_output["audio"] + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + sf.write(output_wav, audio_tensor.detach().cpu().numpy(), samplerate=24000) + print(f"Request ID: {request_id}, Saved audio to {output_wav}") + + processed_count += 1 + if profiler_enabled and hasattr(omni, "stop_profile") and processed_count >= total_requests: print(f"[Info] Processed {processed_count}/{total_requests}. 
Stopping profiler inside active loop...") # Stop the profiler while workers are still alive - omni_llm.stop_profile() + omni.stop_profile() print("[Info] Waiting 30s for workers to write massive trace files to disk...") time.sleep(30) print("[Info] Trace export wait finished.") - omni_llm.close() + omni.close() def parse_args(): diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 509886a02b..0bf143e0ce 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -294,7 +294,7 @@ def main(args): else: query_result = query_func() - omni_llm = Omni( + omni = Omni( model=model_name, stage_configs_path=args.stage_configs_path, log_stats=args.log_stats, @@ -354,8 +354,8 @@ def main(args): profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) if profiler_enabled: - omni_llm.start_profile(stages=[0]) - omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator) + omni.start_profile(stages=[0]) + omni_generator = omni.generate(prompts, sampling_params_list, py_generator=args.py_generator) # Determine output directory: prefer --output-dir; fallback to --output-wav output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav os.makedirs(output_dir, exist_ok=True) @@ -366,51 +366,50 @@ def main(args): print(f"query type: {args.query_type}") for stage_outputs in omni_generator: + output = stage_outputs.request_output if stage_outputs.final_output_type == "text": - for output in stage_outputs.request_output: - request_id = output.request_id - text_output = output.outputs[0].text - # Save aligned text file per request - prompt_text = output.prompt - out_txt = os.path.join(output_dir, f"{request_id}.txt") - lines = [] - lines.append("Prompt:\n") - lines.append(str(prompt_text) + "\n") - lines.append("vllm_text_output:\n") - lines.append(str(text_output).strip() + "\n") - try: - with 
open(out_txt, "w", encoding="utf-8") as f: - f.writelines(lines) - except Exception as e: - print(f"[Warn] Failed writing text file {out_txt}: {e}") - print(f"Request ID: {request_id}, Text saved to {out_txt}") + request_id = output.request_id + text_output = output.outputs[0].text + # Save aligned text file per request + prompt_text = output.prompt + out_txt = os.path.join(output_dir, f"{request_id}.txt") + lines = [] + lines.append("Prompt:\n") + lines.append(str(prompt_text) + "\n") + lines.append("vllm_text_output:\n") + lines.append(str(text_output).strip() + "\n") + try: + with open(out_txt, "w", encoding="utf-8") as f: + f.writelines(lines) + except Exception as e: + print(f"[Warn] Failed writing text file {out_txt}: {e}") + print(f"Request ID: {request_id}, Text saved to {out_txt}") elif stage_outputs.final_output_type == "audio": - for output in stage_outputs.request_output: - request_id = output.request_id - audio_tensor = output.outputs[0].multimodal_output["audio"] - output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + request_id = output.request_id + audio_tensor = output.outputs[0].multimodal_output["audio"] + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - # Convert to numpy array and ensure correct format - audio_numpy = audio_tensor.float().detach().cpu().numpy() + # Convert to numpy array and ensure correct format + audio_numpy = audio_tensor.float().detach().cpu().numpy() - # Ensure audio is 1D (flatten if needed) - if audio_numpy.ndim > 1: - audio_numpy = audio_numpy.flatten() + # Ensure audio is 1D (flatten if needed) + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() - # Save audio file with explicit WAV format - sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV") - print(f"Request ID: {request_id}, Saved audio to {output_wav}") + # Save audio file with explicit WAV format + sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV") + print(f"Request ID: {request_id}, Saved 
audio to {output_wav}") - processed_count += len(stage_outputs.request_output) + processed_count += 1 if profiler_enabled and processed_count >= total_requests: print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") # Stop the profiler while workers are still alive - omni_llm.stop_profile() + omni.stop_profile() print("[Info] Waiting 30s for workers to write trace files to disk...") time.sleep(30) print("[Info] Trace export wait time finished.") - omni_llm.close() + omni.close() def parse_args(): diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py index 965595578f..8adbae9eb6 100644 --- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py +++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py @@ -231,52 +231,47 @@ async def run_single_request( sampling_params_list=sampling_params_list, output_modalities=output_modalities, ): - if not isinstance(omni_output.request_output, list): - outputs_list = [omni_output.request_output] - else: - outputs_list = omni_output.request_output - - for output in outputs_list: - if omni_output.final_output_type == "text": - if stage_0_first_output_ts is None: - stage_0_first_output_ts = time.perf_counter() - text_output = output.outputs[0].text - if output.finished: - text_parts.append(text_output) - elif omni_output.final_output_type == "audio": - mm_out = output.outputs[0].multimodal_output - if mm_out and "audio" in mm_out: - if first_audio_ts is None: - first_audio_ts = time.perf_counter() - if audio_sr is None and "sr" in mm_out: - sr_val = mm_out["sr"] - audio_sr = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) - samplerate = audio_sr - audio_data = mm_out["audio"] - if isinstance(audio_data, list): - new_chunks = audio_data[audio_list_consumed:] - audio_list_consumed = len(audio_data) - elif isinstance(audio_data, torch.Tensor): - new_chunks = [audio_data] - 
audio_last_tensor = audio_data - else: - new_chunks = [] - - if stream_audio_to_disk and new_chunks: - if sf_writer is None: - sf_writer = sf.SoundFile( - wav_file, - mode="w", - samplerate=samplerate, - channels=1, - subtype="FLOAT", - ) - for chunk in new_chunks: - chunk_np = chunk.float().detach().cpu().numpy().flatten() - sf_writer.write(chunk_np) - audio_samples_written += len(chunk_np) - else: - audio_chunks.extend(new_chunks) + output = omni_output.request_output + if omni_output.final_output_type == "text": + if stage_0_first_output_ts is None: + stage_0_first_output_ts = time.perf_counter() + text_output = output.outputs[0].text + if output.finished: + text_parts.append(text_output) + elif omni_output.final_output_type == "audio": + mm_out = output.outputs[0].multimodal_output + if mm_out and "audio" in mm_out: + if first_audio_ts is None: + first_audio_ts = time.perf_counter() + if audio_sr is None and "sr" in mm_out: + sr_val = mm_out["sr"] + audio_sr = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) + samplerate = audio_sr + audio_data = mm_out["audio"] + if isinstance(audio_data, list): + new_chunks = audio_data[audio_list_consumed:] + audio_list_consumed = len(audio_data) + elif isinstance(audio_data, torch.Tensor): + new_chunks = [audio_data] + audio_last_tensor = audio_data + else: + new_chunks = [] + + if stream_audio_to_disk and new_chunks: + if sf_writer is None: + sf_writer = sf.SoundFile( + wav_file, + mode="w", + samplerate=samplerate, + channels=1, + subtype="FLOAT", + ) + for chunk in new_chunks: + chunk_np = chunk.float().detach().cpu().numpy().flatten() + sf_writer.write(chunk_np) + audio_samples_written += len(chunk_np) + else: + audio_chunks.extend(new_chunks) finally: if sf_writer is not None: sf_writer.close() diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index f60365032f..901418c39b 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ 
b/examples/offline_inference/qwen3_tts/end2end.py @@ -51,16 +51,68 @@ def _estimate_prompt_len( tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True) - _cache[model_name] = (tok, getattr(cfg, "talker_config", None)) - tok, tcfg = _cache[model_name] + # Load speech tokenizer (codec encoder) for exact ref_code_len. + speech_tok = None + try: + import os + + from transformers.utils import cached_file + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_tokenizer import Qwen3TTSTokenizer + + st_cfg_path = cached_file(model_name, "speech_tokenizer/config.json") + if st_cfg_path: + speech_tok = Qwen3TTSTokenizer.from_pretrained( + os.path.dirname(st_cfg_path), torch_dtype=torch.bfloat16 + ) + logger.info("Loaded speech tokenizer for exact ref_code_len estimation") + except Exception as e: + logger.debug("Could not load speech tokenizer: %s", e) + + _cache[model_name] = (tok, getattr(cfg, "talker_config", None), speech_tok) + + tok, tcfg, speech_tok = _cache[model_name] task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] + + def _estimate_ref_code_len(ref_audio: object) -> int | None: + """Encode ref_audio with the actual codec to get exact frame count.""" + if not isinstance(ref_audio, (str, list)): + return None + audio_path = ref_audio[0] if isinstance(ref_audio, list) else ref_audio + if not isinstance(audio_path, str) or not audio_path.strip(): + return None + try: + from vllm.multimodal.media import MediaConnector + + connector = MediaConnector(allowed_local_media_path="/") + audio, sr = connector.fetch_audio(audio_path) + import numpy as np + + wav_np = np.asarray(audio, dtype=np.float32) + + if speech_tok is not None: + enc = speech_tok.encode(wav_np, sr=int(sr), return_dict=True) + ref_code = getattr(enc, "audio_codes", None) + if isinstance(ref_code, list): + ref_code = ref_code[0] if ref_code else None + if ref_code is 
not None and hasattr(ref_code, "shape"): + shape = ref_code.shape + return int(shape[0]) if len(shape) == 2 else int(shape[1]) if len(shape) == 3 else None + + # Fallback: estimate from duration + codec_hz = getattr(tcfg, "codec_frame_rate", None) or 12 + return int(len(audio) / sr * codec_hz) + except Exception: + return None + return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( additional_information=additional_information, task_type=task_type, tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], codec_language_id=getattr(tcfg, "codec_language_id", None), spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), + estimate_ref_code_len=_estimate_ref_code_len, ) except Exception as exc: logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc) @@ -325,8 +377,8 @@ def main(args): for batch_start in range(0, len(inputs), batch_size): batch = inputs[batch_start : batch_start + batch_size] for stage_outputs in omni.generate(batch): - for output in stage_outputs.request_output: - _save_wav(output_dir, output.request_id, output.outputs[0].multimodal_output) + output = stage_outputs.request_output + _save_wav(output_dir, output.request_id, output.outputs[0].multimodal_output) async def main_streaming(args): diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py index d8f4e7829f..68aeee265f 100644 --- a/examples/offline_inference/text_to_audio/text_to_audio.py +++ b/examples/offline_inference/text_to_audio/text_to_audio.py @@ -177,7 +177,7 @@ def main(): output = outputs[0] if not hasattr(output, "request_output") or not output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - request_output = output.request_output[0] + request_output = output.request_output if not hasattr(request_output, "multimodal_output"): raise ValueError("No multimodal_output found in request_output") diff --git 
a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index 2513627e85..0de89c753c 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -16,7 +16,7 @@ if __name__ == "__main__": omni = Omni(model="Qwen/Qwen-Image") prompt = "a cup of coffee on the table" outputs = omni.generate(prompt) - images = outputs[0].request_output[0].images + images = outputs[0].request_output.images images[0].save("coffee.png") ``` @@ -34,7 +34,7 @@ if __name__ == "__main__": ] outputs = omni.generate(prompts) for i, output in enumerate(outputs): - image = output.request_output[0].images[0].save(f"{i}.jpg") + image = output.request_output.images[0].save(f"{i}.jpg") ``` !!! info @@ -69,7 +69,7 @@ if __name__ == "__main__": } ]) for i, output in enumerate(outputs): - image = output.request_output[0].images[0].save(f"{i}.jpg") + image = output.request_output.images[0].save(f"{i}.jpg") ``` ## Local CLI Usage diff --git a/examples/offline_inference/text_to_image/gradio_demo.py b/examples/offline_inference/text_to_image/gradio_demo.py index bde6bb0e75..928325cd69 100644 --- a/examples/offline_inference/text_to_image/gradio_demo.py +++ b/examples/offline_inference/text_to_image/gradio_demo.py @@ -112,7 +112,7 @@ def run_inference( ) images_outputs = [] for output in outputs: - req_out = output.request_output[0] + req_out = output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") images = req_out.images diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 6f98403b93..37c2962854 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -14,7 +14,6 @@ from vllm_omni.inputs.data import 
OmniDiffusionSamplingParams from vllm_omni.lora.request import LoRARequest from vllm_omni.lora.utils import stable_lora_int_id -from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform @@ -400,20 +399,18 @@ def main(): else: print("[Profiler] No valid profiling data returned.") - # Extract images from OmniRequestOutput - # omni.generate() returns list[OmniRequestOutput], extract images from the first output + # omni.generate() returns list[OmniRequestOutput] if not outputs or len(outputs) == 0: raise ValueError("No output generated from omni.generate()") logger.info(f"Outputs: {outputs}") - # Extract images from request_output[0]['images'] first_output = outputs[0] if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] - if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): - raise ValueError("Invalid request_output structure or missing 'images' key") + req_out = first_output.request_output + if not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images'.") images = req_out.images if not images: diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index d7274e22bc..148652a13b 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -240,8 +240,6 @@ def main(): audio = frames.multimodal_output["audio"] if frames.is_pipeline_output and frames.request_output is not None: inner_output = frames.request_output - if isinstance(inner_output, list): - inner_output = inner_output[0] if inner_output else None if isinstance(inner_output, OmniRequestOutput): if inner_output.multimodal_output and "audio" in inner_output.multimodal_output: audio = 
inner_output.multimodal_output["audio"] diff --git a/requirements/common.txt b/requirements/common.txt index fc39c6c14f..1b65974fec 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -17,3 +17,4 @@ einops>=0.8.1 prettytable>=3.8.0 aenum==3.1.16 pyzmq>=25.0.0 +janus>=1.0.0 diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py index 24f16d5e9b..f6ce82f9b2 100644 --- a/tests/comfyui/test_comfyui_integration.py +++ b/tests/comfyui/test_comfyui_integration.py @@ -11,6 +11,7 @@ import traceback from collections.abc import Iterable, Sequence from enum import StrEnum, auto +from types import SimpleNamespace from typing import Any, NamedTuple from unittest.mock import AsyncMock, MagicMock, patch @@ -31,6 +32,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm_omni.entrypoints.async_omni import AsyncOmni as RealAsyncOmni from vllm_omni.entrypoints.cli.serve import OmniServeCommand from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -43,7 +45,7 @@ class ServerCase(NamedTuple): served_model: str stage_list: list - stage_configs: list[dict] + stage_configs: list[Any] outputs: list[OmniRequestOutput] @@ -262,6 +264,40 @@ def _assert_model_param_values(received: OmniSamplingParams, expected: dict): ) +def _make_stage_config( + stage_type: str, + *, + is_comprehension: bool = False, + model_stage: str | None = None, +): + engine_args = SimpleNamespace() + if model_stage is not None: + engine_args.model_stage = model_stage + return SimpleNamespace( + stage_type=stage_type, + is_comprehension=is_comprehension, + engine_args=engine_args, + ) + + +def _stage_type(stage: Any) -> str | None: + return getattr(stage, "stage_type", None) + + +def _build_output_modalities(stage_configs: list[Any]) -> list[str]: + modalities: list[str] = [] + for stage in stage_configs: + final_output = 
getattr(stage, "final_output", False) + final_output_type = getattr(stage, "final_output_type", None) + if final_output and isinstance(final_output_type, str): + modalities.append(final_output_type) + if modalities: + return modalities + if any(_stage_type(stage) == "diffusion" for stage in stage_configs): + return ["image"] + return ["text", "audio"] + + def _build_mock_outputs(outputs: Iterable[OmniRequestOutput], sampling_case: SamplingCase, server_case: ServerCase): async def _mock_generate(*args, **kwargs): received_sampling_params_list: Sequence[OmniSamplingParams] | None = ( @@ -367,18 +403,28 @@ async def _mock_preprocess_chat(self, *args, **kwargs): new=_mock_preprocess_chat, ), ): - mock_instance = AsyncMock() + mock_instance = AsyncMock(spec=RealAsyncOmni) mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case) mock_instance.stage_list = server_case.stage_list mock_instance.stage_configs = server_case.stage_configs + mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs) mock_instance.default_sampling_params_list = [ - SamplingParams() if stage.get("stage_type") != "diffusion" else MagicMock() + SamplingParams() if _stage_type(stage) != "diffusion" else MagicMock() for stage in server_case.stage_configs ] mock_instance.errored = False mock_instance.dead_error = RuntimeError("Mock engine error") - mock_instance.model_config = MagicMock(max_model_len=4096, io_processor_plugin=None) + mock_instance.model_config = MagicMock( + max_model_len=4096, + io_processor_plugin=None, + allowed_local_media_path=None, + allowed_media_domains=None, + ) + # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes. 
+ mock_instance.model_config.hf_config = MagicMock() + mock_instance.model_config.hf_config.talker_config = MagicMock() + mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0} mock_instance.io_processor = MagicMock() mock_instance.input_processor = MagicMock() mock_instance.shutdown = MagicMock() @@ -444,7 +490,7 @@ def run_server(): ServerCase( served_model="Tongyi-MAI/Z-Image-Turbo", stage_list=["diffusion"], - stage_configs=[{"stage_type": "diffusion"}], + stage_configs=[_make_stage_config("diffusion")], outputs=[_build_diffusion_image_output_for_images_endpoint()], ), "Tongyi-MAI/Z-Image-Turbo", @@ -455,7 +501,7 @@ def run_server(): ServerCase( served_model="ByteDance-Seed/BAGEL-7B-MoT", stage_list=["diffusion"], - stage_configs=[{"stage_type": "diffusion"}], + stage_configs=[_make_stage_config("diffusion")], outputs=[_build_diffusion_image_output_for_chat_endpoint()], ), "ByteDance-Seed/BAGEL-7B-MoT", @@ -466,7 +512,7 @@ def run_server(): ServerCase( served_model="Qwen/Qwen-Image-Edit", stage_list=["diffusion"], - stage_configs=[{"stage_type": "diffusion"}], + stage_configs=[_make_stage_config("diffusion")], outputs=[_build_diffusion_image_output_for_images_endpoint()], ), "Qwen/Qwen-Image-Edit", @@ -477,7 +523,7 @@ def run_server(): ServerCase( served_model="ByteDance-Seed/BAGEL-7B-MoT", stage_list=["diffusion"], - stage_configs=[{"stage_type": "diffusion"}], + stage_configs=[_make_stage_config("diffusion")], outputs=[_build_diffusion_image_output_for_chat_endpoint()], ), "ByteDance-Seed/BAGEL-7B-MoT", @@ -542,9 +588,9 @@ async def test_image_generation_node(api_server: str, model: str, image_input: b MagicMock(is_comprehension=False, model_stage="llm"), ], stage_configs=[ - {"stage_type": "llm"}, - {"stage_type": "llm"}, - {"stage_type": "llm"}, + _make_stage_config("llm", is_comprehension=True, model_stage="thinker"), + _make_stage_config("llm", is_comprehension=False, model_stage="talker"), + _make_stage_config("llm", 
is_comprehension=False, model_stage="code2wav"), ], outputs=[_build_audio_chat_output(), _build_text_output("Understanding response")], ), @@ -599,7 +645,7 @@ async def test_understanding_node(api_server: str, sampling_case: SamplingCase): ServerCase( served_model="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", stage_list=["llm"], - stage_configs=[{"stage_type": "llm"}], + stage_configs=[_make_stage_config("llm", model_stage="qwen3_tts")], outputs=[_build_audio_speech_output()], ), VLLMOmniTTS, @@ -617,7 +663,7 @@ async def test_understanding_node(api_server: str, sampling_case: SamplingCase): ServerCase( served_model="Qwen/Qwen3-TTS-12Hz-1.7B-Base", stage_list=["llm"], - stage_configs=[{"stage_type": "llm"}], + stage_configs=[_make_stage_config("llm", model_stage="qwen3_tts")], outputs=[_build_audio_speech_output()], ), VLLMOmniVoiceClone, diff --git a/tests/conftest.py b/tests/conftest.py index 0a3c350d75..2460cfd5bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1733,7 +1733,9 @@ def get_default_sampling_params_list(self) -> list[OmniSamplingParams]: Returns: List of SamplingParams with default decoding for each stage """ - return [st.default_sampling_params for st in self.omni.stage_list] + if not hasattr(self.omni, "default_sampling_params_list"): + raise AttributeError("Omni.default_sampling_params_list is not available") + return list(self.omni.default_sampling_params_list) def get_omni_inputs( self, @@ -1974,9 +1976,9 @@ def _process_output(self, outputs: list[Any]) -> OmniResponse: audio_content = None for stage_output in outputs: if getattr(stage_output, "final_output_type", None) == "text": - text_content = stage_output.request_output[0].outputs[0].text + text_content = stage_output.request_output.outputs[0].text if getattr(stage_output, "final_output_type", None) == "audio": - audio_content = stage_output.request_output[0].outputs[0].multimodal_output["audio"] + audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"] 
result.audio_content = audio_content result.text_content = text_content diff --git a/tests/diffusion/distributed/test_distributed_vae_executor.py b/tests/diffusion/distributed/test_distributed_vae_executor.py index b53014f39d..42e9f3300b 100644 --- a/tests/diffusion/distributed/test_distributed_vae_executor.py +++ b/tests/diffusion/distributed/test_distributed_vae_executor.py @@ -70,6 +70,7 @@ def mock_dist(): patch.object(dist, "get_world_size", return_value=2), patch.object(dist, "get_rank", return_value=0), patch.object(dist, "is_initialized", return_value=True), + patch.object(dist, "all_reduce", return_value=None), patch.object(dist, "gather", return_value=None), patch.object(dist, "broadcast", return_value=None), ): diff --git a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml index 87707e9195..acca96e531 100644 --- a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml +++ b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml @@ -1,4 +1,4 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # This config is optimized for CI e2e tests. 
stage_args: diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py index eef0b7d6cf..da9df2778f 100644 --- a/tests/e2e/offline_inference/test_bagel_img2img.py +++ b/tests/e2e/offline_inference/test_bagel_img2img.py @@ -98,9 +98,9 @@ def _extract_generated_image(omni_outputs: list) -> Image.Image | None: if images := getattr(req_output, "images", None): return images[0] if hasattr(req_output, "request_output") and req_output.request_output: - for stage_out in req_output.request_output: - if hasattr(stage_out, "images") and stage_out.images: - return stage_out.images[0] + stage_out = req_output.request_output + if hasattr(stage_out, "images") and stage_out.images: + return stage_out.images[0] return None diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py index 360d49bb1b..7990ac980e 100644 --- a/tests/e2e/offline_inference/test_bagel_text2img.py +++ b/tests/e2e/offline_inference/test_bagel_text2img.py @@ -14,6 +14,9 @@ """ import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" import signal import socket import subprocess @@ -96,9 +99,9 @@ def _extract_generated_image(omni_outputs: list) -> Image.Image | None: if images := getattr(req_output, "images", None): return images[0] if hasattr(req_output, "request_output") and req_output.request_output: - for stage_out in req_output.request_output: - if hasattr(stage_out, "images") and stage_out.images: - return stage_out.images[0] + stage_out = req_output.request_output + if hasattr(stage_out, "images") and stage_out.images: + return stage_out.images[0] return None diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py index 5c754c9a43..0e31413dc0 100644 --- a/tests/e2e/offline_inference/test_cache_dit.py +++ b/tests/e2e/offline_inference/test_cache_dit.py @@ -72,13 +72,13 @@ def 
test_cache_dit(model_name: str): num_outputs_per_prompt=1, # Single output for speed ), ) - # Extract images from request_output[0]['images'] + # Extract images from request_output['images'] first_output = outputs[0] assert first_output.final_output_type == "image" if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py index a95ab43da5..aacc016aa9 100644 --- a/tests/e2e/offline_inference/test_diffusion_lora.py +++ b/tests/e2e/offline_inference/test_diffusion_lora.py @@ -36,7 +36,7 @@ def _extract_images(outputs: list[OmniRequestOutput]): if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") return req_out.images diff --git a/tests/e2e/offline_inference/test_expert_parallel.py b/tests/e2e/offline_inference/test_expert_parallel.py index 61f93e78fe..ba126986ec 100644 --- a/tests/e2e/offline_inference/test_expert_parallel.py +++ b/tests/e2e/offline_inference/test_expert_parallel.py @@ -129,7 +129,7 @@ def _run_inference( elapsed_ms = (time.time() - start) * 1000 return InferenceResult( - images=outputs[0].request_output[0].images, + images=outputs[0].images, elapsed_ms=elapsed_ms, ) finally: diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py 
b/tests/e2e/offline_inference/test_sequence_parallel.py index 0f0255b448..16239a1c52 100644 --- a/tests/e2e/offline_inference/test_sequence_parallel.py +++ b/tests/e2e/offline_inference/test_sequence_parallel.py @@ -130,7 +130,7 @@ def _run_inference( elapsed_ms = (time.time() - start) * 1000 return InferenceResult( - images=outputs[0].request_output[0].images, + images=outputs[0].request_output.images, elapsed_ms=elapsed_ms, ) finally: diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py index c36f064938..ff4d9b4017 100644 --- a/tests/e2e/offline_inference/test_stable_audio_model.py +++ b/tests/e2e/offline_inference/test_stable_audio_model.py @@ -57,7 +57,7 @@ def test_stable_audio_model(model_name: str): assert first_output.final_output_type == "image" assert hasattr(first_output, "request_output") and first_output.request_output - req_out = first_output.request_output[0] + req_out = first_output.request_output assert isinstance(req_out, OmniRequestOutput) assert req_out.final_output_type == "audio" assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py index 0cb0dafeba..77b2b3aaf2 100644 --- a/tests/e2e/offline_inference/test_t2i_model.py +++ b/tests/e2e/offline_inference/test_t2i_model.py @@ -28,7 +28,7 @@ models = ["Tongyi-MAI/Z-Image-Turbo", "Qwen/Qwen-Image"] elif current_omni_platform.is_rocm(): # TODO: When ROCm support is ready, remove this branch. 
- # vLLM V0.11.0 has issues running riverclouds/qwen_image_random + # Current upstream vLLM has issues running riverclouds/qwen_image_random # on ROCm models = ["Tongyi-MAI/Z-Image-Turbo"] @@ -62,13 +62,13 @@ def test_diffusion_model(model_name: str, run_level): num_outputs_per_prompt=2, ), ) - # Extract images from request_output[0]['images'] + # Extract images from request_output['images'] first_output = outputs[0] assert first_output.final_output_type == "image" if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py index 9565534bcc..94c9dedf74 100644 --- a/tests/e2e/offline_inference/test_t2v_model.py +++ b/tests/e2e/offline_inference/test_t2v_model.py @@ -51,7 +51,7 @@ def test_video_diffusion_model(model_name: str): if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py index f2d1101255..efc0e43e86 100644 --- a/tests/e2e/offline_inference/test_teacache.py +++ b/tests/e2e/offline_inference/test_teacache.py @@ -68,13 +68,13 @@ def test_teacache(model_name: str): num_outputs_per_prompt=1, # Single output for speed ), ) - # Extract images from request_output[0]['images'] + # Extract images from 
request_output['images'] first_output = outputs[0] assert first_output.final_output_type == "image" if not hasattr(first_output, "request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") diff --git a/tests/e2e/offline_inference/test_vae_decode_parallelism.py b/tests/e2e/offline_inference/test_vae_decode_parallelism.py index 0bd3e9218a..cee76fac2e 100644 --- a/tests/e2e/offline_inference/test_vae_decode_parallelism.py +++ b/tests/e2e/offline_inference/test_vae_decode_parallelism.py @@ -72,6 +72,7 @@ def is_nextstep_model(model_name: str) -> bool: def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile, vae_patch_parallel_size=1): + m = None try: parallel_config = DiffusionParallelConfig( tensor_parallel_size=tp, @@ -105,7 +106,7 @@ def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile, ) end = time.perf_counter() first_output = outputs[0] - req_out = first_output.request_output[0] + req_out = first_output.request_output frames = req_out.images[0] if isinstance(frames, torch.Tensor): frames = frames.detach().cpu().numpy() @@ -115,7 +116,8 @@ def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile, cost = (end - start) * 1000 return frames, cost finally: - m.close() + if m is not None: + m.close() cleanup_dist_env_and_memory() diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py index 279710195d..9d9db16a40 100644 --- a/tests/e2e/offline_inference/test_zimage_parallelism.py +++ b/tests/e2e/offline_inference/test_zimage_parallelism.py @@ -68,7 +68,7 @@ def _extract_single_image(outputs) -> Image.Image: if not hasattr(first_output, 
"request_output") or not first_output.request_output: raise ValueError("No request_output found in OmniRequestOutput") - req_out = first_output.request_output[0] + req_out = first_output.request_output if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): raise ValueError("Invalid request_output structure or missing 'images' key") diff --git a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml index c11de3e5d6..f3a8ecf79c 100644 --- a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml +++ b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml @@ -1,4 +1,4 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090). # This config is optimized for CI e2e tests. diff --git a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml index d51f7a5c8f..9fa5930c8e 100644 --- a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml +++ b/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml @@ -1,4 +1,4 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090). # This config is optimized for CI e2e tests. diff --git a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml index de7e9d901f..9feaaff67a 100644 --- a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml +++ b/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml @@ -1,9 +1,9 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config is verified with 2 * Intel Arc Pro B60 XPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage @@ -35,7 +35,7 @@ stage_args: detokenize: True repetition_penalty: 1.1 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "1" @@ -67,7 +67,7 @@ stage_args: stop_token_ids: [8294] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "2" diff --git a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml index b362a86678..49de8325a3 100644 --- a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml +++ b/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml @@ -6,7 +6,7 @@ # The following config is verified with 8 * Intel Arc Pro B60 XPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0,1,2,3" max_batch_size: 1 @@ -41,7 +41,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "4" max_batch_size: 1 @@ -75,7 +75,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "5" max_batch_size: 1 diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py new file mode 100644 index 0000000000..b2d2d9a9e5 --- /dev/null +++ b/tests/engine/test_async_omni_engine_input.py @@ -0,0 +1,63 @@ +from unittest.mock import Mock + +import pytest +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import EngineCoreRequest + +from vllm_omni.engine import OmniEngineCoreRequest +from vllm_omni.engine.async_omni_engine import AsyncOmniEngine + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _make_engine_core_request() -> EngineCoreRequest: + return EngineCoreRequest( + request_id="req-1", + prompt_token_ids=[1, 1, 1], + mm_features=None, + sampling_params=SamplingParams(max_tokens=8), + pooling_params=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) + + +def test_build_add_request_message_preserves_additional_information(): + engine = object.__new__(AsyncOmniEngine) + params = SamplingParams(max_tokens=8) + engine.default_sampling_params_list = [params] + engine.stage_metadata = [{"stage_type": "llm"}] + engine.supported_tasks = ("speech",) + + input_processor = Mock() + input_processor.process_inputs.return_value = _make_engine_core_request() + engine.input_processor = input_processor + + output_processor = Mock() + engine.output_processors = 
[output_processor] + + prompt = { + "prompt_token_ids": [1, 1, 1], + "additional_information": { + "text": ["hello world"], + "speaker": ["vivian"], + }, + } + + msg = engine._build_add_request_message( + request_id="req-1", + prompt=prompt, + sampling_params_list=[params], + final_stage_id=0, + arrival_time=0.0, + ) + + request = msg["prompt"] + assert isinstance(request, OmniEngineCoreRequest) + assert request.external_req_id == "req-1" + assert request.additional_information is not None + assert request.additional_information.entries["text"].list_data == ["hello world"] + assert request.additional_information.entries["speaker"].list_data == ["vivian"] + output_processor.add_request.assert_called_once() diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 743e36a05e..35fccce124 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -10,6 +10,7 @@ import base64 import io from argparse import Namespace +from types import SimpleNamespace import pytest from fastapi.testclient import TestClient @@ -115,7 +116,10 @@ class FakeAsyncOmni: """Fake AsyncOmni that yields a single diffusion output.""" def __init__(self): - self.stage_list = ["llm", "diffusion"] + self.stage_configs = [ + SimpleNamespace(stage_type="llm"), + SimpleNamespace(stage_type="diffusion"), + ] self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()] self.captured_sampling_params_list = None self.captured_prompt = None @@ -129,21 +133,25 @@ async def generate(self, prompt, request_id, sampling_params_list): @pytest.fixture def mock_async_diffusion(mocker: MockerFixture): - """Mock AsyncOmniDiffusion instance that returns fake images""" - mock = mocker.Mock() - mock.is_running = True # For health endpoint - mock.check_health = mocker.AsyncMock() # For LLM mode health check + """Mock diffusion engine that matches the current 
async-generator API.""" - async def generate(**kwargs): - # Return n PIL images wrapped in result object - n = kwargs["sampling_params_list"][0].num_outputs_per_prompt - mock.captured_sampling_params_list = kwargs["sampling_params_list"] - mock.captured_prompt = kwargs["prompt"] - images = [Image.new("RGB", (64, 64), color="blue") for _ in range(n)] - return MockGenerationResult(images) + class MockAsyncDiffusion: + def __init__(self) -> None: + self.is_running = True + self.check_health = mocker.AsyncMock() + self.captured_sampling_params_list = None + self.captured_prompt = None + self.generate_calls = 0 - mock.generate = mocker.AsyncMock(side_effect=generate) - return mock + async def generate(self, **kwargs): + self.generate_calls += 1 + n = kwargs["sampling_params_list"][0].num_outputs_per_prompt + self.captured_sampling_params_list = kwargs["sampling_params_list"] + self.captured_prompt = kwargs["prompt"] + images = [Image.new("RGB", (64, 64), color="blue") for _ in range(n)] + yield MockGenerationResult(images) + + return MockAsyncDiffusion() @pytest.fixture @@ -159,7 +167,7 @@ def test_client(mock_async_diffusion): # Set up app state with diffusion engine app.state.engine_client = mock_async_diffusion app.state.diffusion_engine = mock_async_diffusion # Also set for health endpoint - app.state.stage_configs = [{"stage_type": "diffusion"}] + app.state.stage_configs = [SimpleNamespace(stage_type="diffusion")] app.state.diffusion_model_name = "Qwen/Qwen-Image" # For models endpoint app.state.args = Namespace( default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}', @@ -180,7 +188,32 @@ def async_omni_test_client(): app.include_router(router) app.state.engine_client = FakeAsyncOmni() - app.state.stage_configs = [{"stage_type": "llm"}, {"stage_type": "diffusion"}] + app.state.stage_configs = [ + SimpleNamespace(stage_type="llm"), + SimpleNamespace(stage_type="diffusion"), + ] + app.state.args = Namespace( + default_sampling_params='{"1": 
{"num_inference_steps":4, "guidance_scale":7.5}}', + max_generated_image_size=4096, # 64*64 + ) + return TestClient(app) + + +@pytest.fixture +def async_omni_stage_configs_only_client(): + """Create test client with refactored AsyncOmni compatibility surface only.""" + from fastapi import FastAPI + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + + engine = FakeAsyncOmni() + assert not hasattr(engine, "stage_list") + app.state.engine_client = engine + # Intentionally do not populate app.state.stage_configs. Refactored + # AsyncOmni exposes stage_configs on the engine instance. app.state.args = Namespace( default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}', max_generated_image_size=4096, # 64*64 @@ -291,6 +324,45 @@ def test_generate_images_async_omni_sampling_params(async_omni_test_client): assert captured[1].seed == 7 +def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_only_client): + """Regression: image generation accepts refactored AsyncOmni without stage_list.""" + response = async_omni_stage_configs_only_client.post( + "/v1/images/generations", + json={ + "prompt": "a castle", + "n": 1, + "size": "256x256", + "seed": 11, + }, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["data"]) == 1 + engine = async_omni_stage_configs_only_client.app.state.engine_client + captured = engine.captured_sampling_params_list + assert captured is not None + assert len(captured) == 2 + assert captured[1].seed == 11 + + +def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client): + """Regression: image edits accepts refactored AsyncOmni without stage_list.""" + img_bytes = make_test_image_bytes((16, 16)) + response = async_omni_stage_configs_only_client.post( + "/v1/images/edits", + files=[("image", img_bytes)], + data={ + "prompt": "edit me", + "size": "auto", + }, + ) + assert 
response.status_code == 200 + engine = async_omni_stage_configs_only_client.app.state.engine_client + captured = engine.captured_sampling_params_list + assert captured is not None + assert len(captured) == 2 + + def test_generate_multiple_images(test_client): """Test generating multiple images""" response = test_client.post( @@ -538,13 +610,12 @@ def test_parameters_passed_through(test_client, mock_async_diffusion): ) assert response.status_code == 200 - # Ensure generate() was called exactly once - mock_async_diffusion.generate.assert_awaited_once() - call_kwargs = mock_async_diffusion.generate.call_args[1]["sampling_params_list"][0] - assert call_kwargs.num_inference_steps == 100 - assert call_kwargs.guidance_scale == 7.5 - assert call_kwargs.true_cfg_scale == 3.0 - assert call_kwargs.seed == 42 + assert mock_async_diffusion.generate_calls == 1 + captured = mock_async_diffusion.captured_sampling_params_list[0] + assert captured.num_inference_steps == 100 + assert captured.guidance_scale == 7.5 + assert captured.true_cfg_scale == 3.0 + assert captured.seed == 42 def test_model_field_omitted_works(test_client): diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py index b93592b398..fa4c1e195d 100644 --- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py +++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py @@ -63,9 +63,9 @@ def mock_engine_client( default_other_params, mocker: MockerFixture, ): - """Create mock engine client with stage_list and default_sampling_params_list.""" + """Create mock engine client with stage_configs and default_sampling_params_list.""" engine_client = mocker.MagicMock() - engine_client.stage_list = [mock_comprehension_stage, mock_other_stage] + engine_client.stage_configs = [mock_comprehension_stage, mock_other_stage] engine_client.default_sampling_params_list = [ default_comprehension_params, 
default_other_params, @@ -311,7 +311,7 @@ def test_get_comprehension_stage_index_finds_second_stage(mocker: MockerFixture) comprehension.is_comprehension = True instance.engine_client = mocker.MagicMock() - instance.engine_client.stage_list = [other, comprehension] + instance.engine_client.stage_configs = [other, comprehension] assert instance._get_comprehension_stage_index() == 1 @@ -328,7 +328,7 @@ def test_get_comprehension_stage_index_raises_when_not_found(mocker: MockerFixtu stage2.is_comprehension = False instance.engine_client = mocker.MagicMock() - instance.engine_client.stage_list = [stage1, stage2] + instance.engine_client.stage_configs = [stage1, stage2] with pytest.raises(ValueError, match="No comprehension stage"): instance._get_comprehension_stage_index() diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index e909923172..67abd7617b 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -487,7 +487,7 @@ class TestTTSMethods: def speech_server(self, mocker: MockerFixture): mock_engine_client = mocker.MagicMock() mock_engine_client.errored = False - mock_engine_client.stage_list = None + mock_engine_client.stage_configs = [] mock_engine_client.tts_max_instructions_length = None mock_models = mocker.MagicMock() mock_models.is_base_model.return_value = True @@ -499,7 +499,7 @@ def speech_server(self, mocker: MockerFixture): def test_is_tts_detection_no_stage(self, speech_server): """Test TTS model detection when no TTS stage exists.""" - # Fixture creates server with stage_list = None -> _is_tts should be False + # Fixture creates server with stage_configs = [] -> _is_tts should be False assert speech_server._is_tts is False assert speech_server._tts_stage is None @@ -511,9 +511,9 @@ def test_is_tts_detection_with_tts_stage(self, mocker: MockerFixture): # Create a TTS stage mock_stage = mocker.MagicMock() - 
mock_stage.model_stage = "qwen3_tts" + mock_stage.engine_args.model_stage = "qwen3_tts" mock_stage.tts_args = {} - mock_engine_client.stage_list = [mock_stage] + mock_engine_client.stage_configs = [mock_stage] mock_models = mocker.MagicMock() mock_models.is_base_model.return_value = True @@ -614,7 +614,7 @@ def test_load_supported_speakers(self, mocker: MockerFixture): """Test _load_supported_speakers.""" mock_engine_client = mocker.MagicMock() mock_engine_client.errored = False - mock_engine_client.stage_list = None + mock_engine_client.stage_configs = [] # Mock talker_config with mixed-case speaker names mock_talker_config = mocker.MagicMock() @@ -757,7 +757,7 @@ def test_max_instructions_length_cli_override(self, mocker: MockerFixture): """Test CLI override (stored in engine_client) takes highest priority.""" mock_engine_client = mocker.MagicMock() mock_engine_client.errored = False - mock_engine_client.stage_list = None + mock_engine_client.stage_configs = [] # CLI override is stored in engine_client mock_engine_client.tts_max_instructions_length = 1000 mock_models = mocker.MagicMock() @@ -781,9 +781,9 @@ def test_max_instructions_length_stage_config(self, mocker: MockerFixture): # Mock stage with tts_args mock_stage = mocker.MagicMock() - mock_stage.model_stage = "qwen3_tts" + mock_stage.engine_args.model_stage = "qwen3_tts" mock_stage.tts_args = {"max_instructions_length": 750} - mock_engine_client.stage_list = [mock_stage] + mock_engine_client.stage_configs = [mock_stage] server = OmniOpenAIServingSpeech( engine_client=mock_engine_client, @@ -804,9 +804,9 @@ def test_max_instructions_length_cli_overrides_stage_config(self, mocker: Mocker # Mock stage with tts_args mock_stage = mocker.MagicMock() - mock_stage.model_stage = "qwen3_tts" + mock_stage.engine_args.model_stage = "qwen3_tts" mock_stage.tts_args = {"max_instructions_length": 750} - mock_engine_client.stage_list = [mock_stage] + mock_engine_client.stage_configs = [mock_stage] server = 
OmniOpenAIServingSpeech( engine_client=mock_engine_client, @@ -820,7 +820,7 @@ def test_validate_instructions_length_uses_cached_value(self, mocker: MockerFixt """Test instructions length validation uses cached _max_instructions_length.""" mock_engine_client = mocker.MagicMock() mock_engine_client.errored = False - mock_engine_client.stage_list = None + mock_engine_client.stage_configs = [] # CLI override with max length of 10 characters mock_engine_client.tts_max_instructions_length = 10 mock_models = mocker.MagicMock() @@ -1034,15 +1034,12 @@ class TestAsyncOmniSupportedTasks: @pytest.mark.asyncio async def test_tts_only_no_generate_task(self): """TTS-only models (audio output, no text) should not include 'generate'.""" - from unittest.mock import MagicMock + from types import SimpleNamespace from vllm_omni.entrypoints.async_omni import AsyncOmni omni = AsyncOmni.__new__(AsyncOmni) - omni.output_modalities = [None, "audio"] - stage = MagicMock() - stage.is_comprehension = False - omni.stage_list = [stage] + omni.engine = SimpleNamespace(supported_tasks=("speech",)) tasks = await omni.get_supported_tasks() assert "generate" not in tasks assert "speech" in tasks @@ -1050,15 +1047,12 @@ async def test_tts_only_no_generate_task(self): @pytest.mark.asyncio async def test_omni_model_includes_generate(self): """Models with text output (e.g. 
Qwen3-Omni) should include 'generate'.""" - from unittest.mock import MagicMock + from types import SimpleNamespace from vllm_omni.entrypoints.async_omni import AsyncOmni omni = AsyncOmni.__new__(AsyncOmni) - omni.output_modalities = ["text", None, "audio"] - stage = MagicMock() - stage.is_comprehension = True - omni.stage_list = [stage] + omni.engine = SimpleNamespace(supported_tasks=("generate", "speech")) tasks = await omni.get_supported_tasks() assert "generate" in tasks diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index 0f905d9e0b..f8a8a1e4b4 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -44,7 +44,7 @@ def __init__(self, videos, audios=None, sample_rate=None): class FakeAsyncOmni: def __init__(self): - self.stage_list = ["diffusion"] + self.stage_configs = [SimpleNamespace(stage_type="diffusion")] self.captured_prompt = None self.captured_sampling_params_list = None diff --git a/tests/entrypoints/test_async_omni.py b/tests/entrypoints/test_async_omni.py new file mode 100644 index 0000000000..c65baa3ceb --- /dev/null +++ b/tests/entrypoints/test_async_omni.py @@ -0,0 +1,65 @@ +from types import SimpleNamespace + +import pytest +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels + +from vllm_omni.entrypoints.async_omni import AsyncOmni + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.mark.asyncio +async def test_get_supported_tasks_returns_engine_supported_tasks(): + omni = object.__new__(AsyncOmni) + omni.engine = SimpleNamespace(supported_tasks=("generate", "speech")) + + supported_tasks = await omni.get_supported_tasks() + + assert supported_tasks == ("generate", "speech") + + +def test_model_config_and_vllm_config_forward_from_comprehension_stage(): + model_config = SimpleNamespace(model="Qwen/Qwen3-TTS") + 
vllm_config = SimpleNamespace(model_config=model_config) + renderer = SimpleNamespace(name="renderer") + input_processor = SimpleNamespace(renderer=renderer) + io_processor = SimpleNamespace(name="io-processor") + omni = object.__new__(AsyncOmni) + omni.engine = SimpleNamespace( + stage_clients=[SimpleNamespace(is_comprehension=False), SimpleNamespace(is_comprehension=True)], + stage_vllm_configs=[None, vllm_config], + ) + omni.input_processor = input_processor + omni.io_processor = io_processor + + assert omni.vllm_config is vllm_config + assert omni.model_config is model_config + assert omni.renderer is renderer + assert omni.input_processor is input_processor + assert omni.io_processor is io_processor + + +def test_openai_serving_models_can_consume_async_omni_compat_attrs(): + model_config = SimpleNamespace(model="Qwen/Qwen3-TTS", max_model_len=32768) + vllm_config = SimpleNamespace(model_config=model_config) + renderer = SimpleNamespace(name="renderer") + input_processor = SimpleNamespace(renderer=renderer) + io_processor = SimpleNamespace(name="io-processor") + omni = object.__new__(AsyncOmni) + omni.engine = SimpleNamespace( + stage_clients=[SimpleNamespace(is_comprehension=True)], + stage_vllm_configs=[vllm_config], + ) + omni.input_processor = input_processor + omni.io_processor = io_processor + + serving_models = OpenAIServingModels( + engine_client=omni, + base_model_paths=[BaseModelPath(name="tts-model", model_path="Qwen/Qwen3-TTS")], + ) + + assert serving_models.model_config is model_config + assert serving_models.renderer is renderer + assert serving_models.io_processor is io_processor + assert serving_models.input_processor is input_processor diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py index ed205032ee..06939a7d15 100644 --- a/tests/entrypoints/test_async_omni_diffusion_config.py +++ b/tests/entrypoints/test_async_omni_diffusion_config.py @@ -3,87 +3,51 @@ import pytest 
-from vllm_omni.entrypoints import utils as utils_module -from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.engine.async_omni_engine import AsyncOmniEngine pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -MODEL = "riverclouds/qwen_image_random" - -def _noop_inline_engine(self, model, stage_config, kwargs): - self._inline_diffusion = False - self._inline_engine = None - - -def test_default_stage_config_includes_cache_backend(monkeypatch): - """Ensure cache_backend/cache_config are preserved in default diffusion stage.""" - monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) - monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None) - monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) - monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) - monkeypatch.setattr(AsyncOmni, "_init_inline_diffusion_engine", _noop_inline_engine) - - omni = AsyncOmni( - model=MODEL, - cache_backend="cache_dit", - cache_config='{"Fn_compute_blocks": 2}', - vae_use_slicing=True, - ulysses_degree=2, - ) - - stage_cfg = omni.stage_configs[0] - engine_args = stage_cfg.engine_args - - assert engine_args.get("cache_backend") == "cache_dit" - cache_config = engine_args.get("cache_config") - assert cache_config["Fn_compute_blocks"] == 2 - assert engine_args.get("vae_use_slicing") is True - parallel_config = engine_args.get("parallel_config") - if hasattr(parallel_config, "get"): - ulysses_degree = parallel_config.get("ulysses_degree") - else: - ulysses_degree = getattr(parallel_config, "ulysses_degree", None) - assert ulysses_degree == 2 - - -def test_default_cache_config_used_when_missing(monkeypatch): - """Ensure default cache_config is applied when cache_backend is set.""" - monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) - monkeypatch.setattr(utils_module, 
"resolve_model_config_path", lambda model: None) - monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None) - monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) - monkeypatch.setattr(AsyncOmni, "_init_inline_diffusion_engine", _noop_inline_engine) - - omni = AsyncOmni( - model=MODEL, - cache_backend="cache_dit", - ) - - engine_args = omni.stage_configs[0].engine_args - cache_config = engine_args.get("cache_config") +def test_default_stage_config_includes_cache_backend(): + """Ensure cache knobs survive the default diffusion-stage builder.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "cache_backend": "cache_dit", + "cache_config": '{"Fn_compute_blocks": 2}', + "vae_use_slicing": True, + "ulysses_degree": 2, + } + )[0] + + engine_args = stage_cfg["engine_args"] + assert stage_cfg["stage_type"] == "diffusion" + assert engine_args["cache_backend"] == "cache_dit" + assert engine_args["cache_config"]["Fn_compute_blocks"] == 2 + assert engine_args["vae_use_slicing"] is True + assert engine_args["parallel_config"].ulysses_degree == 2 + assert engine_args["model_stage"] == "diffusion" + + +def test_default_cache_config_used_when_missing(): + """Ensure default cache_config is synthesized when only backend is given.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "cache_backend": "cache_dit", + } + )[0] + + cache_config = stage_cfg["engine_args"]["cache_config"] assert cache_config is not None assert cache_config["Fn_compute_blocks"] == 1 -def test_default_stage_devices_from_sequence_parallel(monkeypatch): - """Ensure devices list reflects sequence parallel size when no parallel_config is provided.""" - monkeypatch.setattr(utils_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: []) - monkeypatch.setattr(utils_module, "resolve_model_config_path", lambda model: None) - monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: 
None) - monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None) - monkeypatch.setattr(AsyncOmni, "_init_inline_diffusion_engine", _noop_inline_engine) - - omni = AsyncOmni( - model=MODEL, - ulysses_degree=2, - ring_degree=2, - ) +def test_default_stage_devices_from_sequence_parallel(): + """Ensure runtime devices reflect computed diffusion world size.""" + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "ulysses_degree": 2, + "ring_degree": 2, + } + )[0] - stage_cfg = omni.stage_configs[0] - runtime = stage_cfg.runtime - if hasattr(runtime, "get"): - devices = runtime.get("devices") - else: - devices = getattr(runtime, "devices", None) - assert devices == "0,1,2,3" + assert stage_cfg["runtime"]["devices"] == "0,1,2,3" diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py deleted file mode 100644 index 9e555fa85c..0000000000 --- a/tests/entrypoints/test_omni_diffusion.py +++ /dev/null @@ -1,1431 +0,0 @@ -import uuid -import warnings -from queue import Empty, Queue -from typing import Any - -import pytest -import torch -from pytest_mock import MockerFixture - -from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK -from vllm_omni.inputs.data import OmniDiffusionSamplingParams - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - -# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. 
-warnings.filterwarnings( - "ignore", - message=r"builtin type SwigPy.*has no __module__ attribute", - category=DeprecationWarning, -) - -MODEL = "riverclouds/qwen_image_random" - - -class _FakeStageRequestStats: - """Fake StageRequestStats object with necessary attributes aligned with real StageRequestStats.""" - - def __init__(self, **kwargs): - # Required fields (with defaults for testing) - self.batch_id = kwargs.get("batch_id", 0) - self.batch_size = kwargs.get("batch_size", 1) - self.num_tokens_in = kwargs.get("num_tokens_in", 0) - self.num_tokens_out = kwargs.get("num_tokens_out", 1) - self.stage_gen_time_ms = kwargs.get("stage_gen_time_ms", 10.0) - self.rx_transfer_bytes = kwargs.get("rx_transfer_bytes", 0) - self.rx_decode_time_ms = kwargs.get("rx_decode_time_ms", 0.0) - self.rx_in_flight_time_ms = kwargs.get("rx_in_flight_time_ms", 0.0) - self.stage_stats = kwargs.get("stage_stats", None) - - # Optional fields - self.stage_id = kwargs.get("stage_id", None) - self.final_output_type = kwargs.get("final_output_type", None) - self.request_id = kwargs.get("request_id", None) - self.postprocess_time_ms = kwargs.get("postprocess_time_ms", 0.0) - self.diffusion_metrics = kwargs.get("diffusion_metrics", None) - self.audio_generated_frames = kwargs.get("audio_generated_frames", 0) - - # Allow additional attributes for flexibility - for key, value in kwargs.items(): - if not hasattr(self, key): - setattr(self, key, value) - - -class _FakeEngineArgs(dict): - """Fake engine args that can be used both as object attributes and as **kwargs.""" - - def __init__(self, args_dict: dict[str, Any]): - super().__init__(args_dict) - # Add required attributes if not present - if "model_stage" not in self: - self["model_stage"] = None - if "engine_output_type" not in self: - self["engine_output_type"] = None - # Also set as attributes for object-style access - for key, value in self.items(): - setattr(self, key, value) - - -class _FakeStageConfig: - """Fake stage config object 
that mimics the real stage config structure.""" - - def __init__(self, config_dict: dict[str, Any]): - # engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs) - engine_args_dict = config_dict.get("engine_args", {}) - self.engine_args = _FakeEngineArgs(engine_args_dict) - self.final_output = config_dict.get("final_output", False) - self.final_output_type = config_dict.get("final_output_type", None) - self.stage_id = config_dict.get("stage_id", 0) - # Store original dict for reference - self._config_dict = config_dict - - -class _FakeQueue: - """Fake queue using standard library Queue to replace mp.Queue.""" - - def __init__(self, maxsize=0): - self._queue = Queue(maxsize=maxsize) - - def put(self, item): - self._queue.put(item) - - def put_nowait(self, item): - self._queue.put_nowait(item) - - def get(self, timeout=None): - if timeout is None: - return self._queue.get() - return self._queue.get(timeout=timeout) - - def get_nowait(self): - return self._queue.get_nowait() - - def empty(self): - return self._queue.empty() - - -class _FakeZmqQueue: - """Fake ZmqQueue that wraps _FakeQueue and matches ZmqQueue interface.""" - - def __init__( - self, - ctx=None, - socket_type=None, - *, - bind: str | None = None, - connect: str | None = None, - recv_timeout_ms: int | None = None, - send_timeout_ms: int | None = None, - ): - """Initialize fake ZMQ queue with same signature as real ZmqQueue.""" - self._queue = _FakeQueue(maxsize=0) - # Determine endpoint from bind or connect - path = bind if bind is not None else connect - self.endpoint = path or f"fake://zmq-endpoint-{id(self)}" - self._recv_timeout_ms = recv_timeout_ms - self._send_timeout_ms = send_timeout_ms - - def put(self, obj: Any) -> None: - """Send an object to the queue.""" - self._queue.put(obj) - - def put_nowait(self, obj: Any) -> None: - """Send an object to the queue without blocking.""" - self._queue.put_nowait(obj) - - def get(self, timeout: float | None = None) -> Any: - 
"""Receive an object from the queue with optional timeout in seconds.""" - return self._queue.get(timeout=timeout) - - def get_nowait(self) -> Any: - """Receive an object from the queue without blocking.""" - return self._queue.get_nowait() - - def empty(self) -> bool: - """Check if the queue is empty without blocking.""" - return self._queue.empty() - - def close(self) -> None: - """Close the queue.""" - pass - - -class _FakeStage: - """Lightweight Stage stub for multi-process pipeline version with queue support.""" - - def __init__(self, mocker: MockerFixture, config, stage_init_timeout: int = 300): - # Handle both dict and object configs - if isinstance(config, dict): - config = _FakeStageConfig(config) - self.config = config - self.stage_config = config - self.engine = None - self.engine_outputs = None - # Set attributes that OmniStage expects - self.stage_id = getattr(config, "stage_id", 0) - self.engine_args = config.engine_args - self.model_stage = getattr(config.engine_args, "model_stage", None) - self.stage_type = "diffusion" - # set default sampling params - self.default_sampling_params = OmniDiffusionSamplingParams(num_inference_steps=1) - # Allow configuring final_output and final_output_type - self.final_output = config.final_output if hasattr(config, "final_output") else False - self.final_output_type = getattr(config, "final_output_type", None) - # Configurable processing logic, default returns placeholder - processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) - self._processed_input = processed_input - # Queue references (set by attach_queues) - self._in_q = None - self._out_q = None - self._proc = None # Mock process reference - self._stage_init_timeout = max(0, int(stage_init_timeout)) - # NOTE: pytest mocker fixture object is injected into test functions and classes - # with specific naming (e.g. 
`class TestXXX`, `def test_XXX` ) automatically, - # but **must be provided** when initializing this helper cls - self._mocker = mocker - - def attach_queues(self, in_q, out_q): - """Attach input and output queues.""" - self._in_q = in_q - self._out_q = out_q - - def init_stage_worker( - self, - model: str, - *, - is_async: bool = False, - shm_threshold_bytes: int = 65536, - ctx=None, - batch_timeout: int = 10, - **kwargs, - ): - """Mock init_stage_worker: don't start real process, just send stage_ready message.""" - # Create a mock process object - self._proc = self._mocker.MagicMock() - self._proc.start = self._mocker.MagicMock() - self._proc.join = self._mocker.MagicMock() - self._proc.is_alive = self._mocker.MagicMock(return_value=False) - self._proc.terminate = self._mocker.MagicMock() - # Send stage_ready message to output queue - if self._out_q is not None: - try: - self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) - except Exception: - pass - - def stop_stage_worker(self): - """Mock stop_stage_worker: clean up queue references.""" - if self._in_q is not None: - try: - self._in_q.put_nowait(SHUTDOWN_TASK) - except Exception: - pass - - def submit(self, payload: dict[str, Any]): - """Submit task to input queue.""" - if self._in_q is not None: - self._in_q.put(payload) - - def try_collect(self) -> Any: - """Non-blocking collect from output queue.""" - if self._out_q is None: - return None - try: - return self._out_q.get_nowait() - except Empty: - return None - - def set_engine_outputs(self, outputs): - """Set engine outputs for the stage.""" - self.engine_outputs = outputs - - def process_engine_inputs(self, stage_list, prompts): - """Process engine inputs: return preset processed result.""" - return self._processed_input - - -class _FakeEngine: - """Lightweight Engine stub: provides generate iterator output.""" - - def __init__(self, outputs: list[Any]): - self._outputs = outputs - - def generate(self, prompts, sampling_params): - # 
Record the most recent prompts for outer assertions - self._last_prompts = prompts - # Simplified: return preset list at once, ensuring iterability - yield from self._outputs - - -@pytest.fixture -def fake_stage_config(): - return { - # Don't include 'model' in engine_args since it's passed separately - "engine_args": {}, - "final_output": True, - "final_output_type": "text", - # Second stage will use processed_input to verify the chain - "processed_input": ["processed-by-stage"], - } - - -def _setup_engine_mocks(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture): - """Helper function to set up common engine mocks. - fixture objects `monkeypatch` and `mocker` must be passed in. - """ - fake_engine = mocker.MagicMock() - # Add necessary attributes to fake_engine - fake_engine.tokenizer = mocker.MagicMock() - fake_engine.log_stats = False - fake_engine.vllm_config = mocker.MagicMock() - fake_engine.vllm_config.model_config = mocker.MagicMock() - fake_engine.vllm_config.model_config.io_processor_plugin = None - fake_engine.get_supported_tasks = mocker.MagicMock(return_value=[]) - fake_engine.model_config = mocker.MagicMock() - fake_engine.model_config.io_processor_plugin = None - # Add registry with resolve_model_cls method - fake_registry = mocker.MagicMock() - fake_registry.resolve_model_cls = mocker.MagicMock(return_value=(mocker.MagicMock(), "test_arch")) - fake_engine.model_config.registry = fake_registry - fake_engine.vllm_config.model_config.registry = fake_registry - - monkeypatch.setattr( - "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", - lambda **kw: fake_engine, - raising=False, - ) - - # Mock model_config.registry.resolve_model_cls to return a tuple - # Use a real class instead of MagicMock to avoid inspect.getsource issues - class FakeModelClass: - pass - - monkeypatch.setattr( - "vllm.model_executor.model_loader.utils.get_model_architecture", - lambda model_config: (FakeModelClass, "test_arch"), - raising=False, - ) - - monkeypatch.setattr( 
- "vllm.model_executor.model_loader.utils._get_model_architecture", - lambda model_config: (FakeModelClass, "test_arch"), - raising=False, - ) - - # Mock try_create_mm_pooling_model_cls to return the class as-is - monkeypatch.setattr( - "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", - lambda model_cls: model_cls, - raising=False, - ) - - # Mock _enable_processor_cache to return False - monkeypatch.setattr( - "vllm.multimodal.cache._enable_processor_cache", - lambda model_config, mm_registry: False, - raising=False, - ) - - # Mock get_io_processor to return None - monkeypatch.setattr( - "vllm.plugins.io_processors.get_io_processor", - lambda vllm_config, io_processor_plugin: None, - raising=False, - ) - - -def _setup_multiprocessing_mocks(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture): - """Helper function to set up multiprocessing mocks. - fixture objects `monkeypatch` and `mocker` must be passed in. - """ - import multiprocessing as mp - - # Mock Process - fake_process_class = mocker.MagicMock() - fake_process_instance = mocker.MagicMock() - fake_process_instance.start = mocker.MagicMock() - fake_process_instance.join = mocker.MagicMock() - fake_process_instance.is_alive = mocker.MagicMock(return_value=True) - fake_process_instance.terminate = mocker.MagicMock() - fake_process_class.return_value = fake_process_instance - - # Mock get_context to return a context with Queue that returns _FakeQueue - fake_ctx = mocker.MagicMock() - fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) - fake_ctx.Process = fake_process_class - - def _mock_get_context(method): - return fake_ctx - - monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False) - monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) - - # Mock ZmqQueue to use _FakeZmqQueue - monkeypatch.setattr( - "vllm_omni.entrypoints.zmq_utils.ZmqQueue", - _FakeZmqQueue, - raising=False, - ) - # Also mock where ZmqQueue is imported/used - 
monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.ZmqQueue", - _FakeZmqQueue, - raising=False, - ) - - -def _setup_ipc_mocks(monkeypatch: pytest.MonkeyPatch): - """Helper function to set up IPC function mocks.""" - - # Mock _encode: simple serialization - def _fake_encode(obj, threshold, obj_key, shm_key): - return {obj_key: obj} - - # Mock _load: extract object from result - def _fake_load(result, obj_key, shm_key): - return result.get(obj_key) - - # Mock _set: calculate serialization size - def _fake_set(obj): - return str(obj).encode() - - monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) - monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) - monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) - - -def _setup_log_mocks(monkeypatch: pytest.MonkeyPatch): - """Helper function to set up logging and stats mocks.""" - # Mock OrchestratorMetrics to be a simple class that doesn't require file operations - - class _FakeOrchestratorMetrics: - def __init__(self, num_stages, enable_stats, wall_start_ts): - self.num_stages = num_stages - self.enable_stats = enable_stats - self.stage_first_ts = [None] * num_stages - self.stage_last_ts = [None] * num_stages - self.e2e_done = set() - - def on_stage_metrics(self, stage_id, req_id, metrics): - pass - - def on_finalize_request(self, stage_id, req_id, start_ts): - self.e2e_done.add(req_id) - - def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): - pass - - def build_and_log_summary(self, final_stage_id): - return "Fake summary" - - monkeypatch.setattr( - "vllm_omni.entrypoints.omni.OrchestratorMetrics", - _FakeOrchestratorMetrics, - raising=False, - ) - - -def _setup_connector_mocks(monkeypatch, mocker, omni_module=None): - """Helper function to set up connector mocks for stage-to-stage forwarding. - - If omni_module is provided, mocks directly on the module. Otherwise, uses string path. 
- """ - - # Mock initialize_orchestrator_connectors to return fake connectors - def _fake_initialize_orchestrator_connectors(config_path, worker_backend=None, shm_threshold_bytes=None): - # Create fake connectors for all stage-to-stage edges - # Each connector is just a mock object that will be passed to try_send_via_connector - fake_connectors = {} - # Add connectors for common edges (0->1, 1->2, etc.) - for i in range(10): # Support up to 10 stages - fake_connectors[(str(i), str(i + 1))] = mocker.MagicMock() - return None, fake_connectors - - if omni_module is not None: - # Mock directly on the omni module where it's used (after import) - monkeypatch.setattr(omni_module, "initialize_orchestrator_connectors", _fake_initialize_orchestrator_connectors) - else: - # Mock via string path (before import) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni.initialize_orchestrator_connectors", - _fake_initialize_orchestrator_connectors, - raising=False, - ) - - -def _setup_connector_adapter_mock(monkeypatch, omni_module): - """Helper function to mock try_send_via_connector on the omni module. - - This must be called AFTER importing omni module, to mock the function where it's actually used. 
- """ - - # Mock try_send_via_connector to always succeed - def _fake_try_send_via_connector( - connector, - stage_id, - next_stage_id, - req_id, - next_inputs, - sampling_params, - original_prompt, - next_stage_queue_submit_fn, - metrics, - ): - # Simulate successful send by calling the submit function - task = { - "request_id": req_id, - "engine_inputs": next_inputs, - "sampling_params": sampling_params, - } - next_stage_queue_submit_fn(task) - return True - - # Mock directly on the omni module where it's used - monkeypatch.setattr(omni_module, "try_send_via_connector", _fake_try_send_via_connector) - - -@pytest.fixture(autouse=True) -def mock_get_config(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture): - """Auto-mock get_config and related model loading functions to avoid model path validation.""" - # CRITICAL: Mock tokenizer-related imports FIRST, before any module imports - # This prevents ImportError when async_omni is imported (which happens via omni_stage) - import sys - - fake_tokenizer = mocker.MagicMock() - fake_tokenizer.encode = mocker.MagicMock(return_value=[1, 2, 3]) - fake_tokenizer.decode = mocker.MagicMock(return_value="test") - - # Mock init_tokenizer_from_configs (used in async_omni) - def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): - return fake_tokenizer - - # Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer) - # This works if the module hasn't been imported yet - monkeypatch.setattr( - "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", - _mock_init_tokenizer_from_configs, - raising=False, - ) - - # Strategy 2: If the module is already in sys.modules, patch it directly - tokenizer_module_path = "vllm.transformers_utils.tokenizer" - if tokenizer_module_path in sys.modules: - tokenizer_module = sys.modules[tokenizer_module_path] - setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) - - # CRITICAL: Mock 
length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni - - # This is because async_omni imports processor.py, which imports this function at module level - # Mock length_from_prompt_token_ids_or_embeds (used in processor.py) - def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): - # Return a reasonable default length - if prompt_token_ids is not None: - if isinstance(prompt_token_ids, list): - return len(prompt_token_ids) - elif hasattr(prompt_token_ids, "shape"): - return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1 - if prompt_embeds is not None: - if hasattr(prompt_embeds, "shape"): - return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1 - return 10 # Default length - - # Mock in vllm.utils - monkeypatch.setattr( - "vllm.utils.length_from_prompt_token_ids_or_embeds", - _mock_length_from_prompt_token_ids_or_embeds, - raising=False, - ) - # Also mock in processor module if it's imported - monkeypatch.setattr( - "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", - _mock_length_from_prompt_token_ids_or_embeds, - raising=False, - ) - # If processor module is already imported, patch it directly - processor_module_path = "vllm_omni.engine.input_processor" - if processor_module_path in sys.modules: - processor_module = sys.modules[processor_module_path] - setattr( - processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds - ) - - # Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked - # This prevents ImportError when async_omni imports processor.py - monkeypatch.setattr( - "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", - _mock_init_tokenizer_from_configs, - raising=False, - ) - - # Strategy 4: If async_omni is already imported, patch it directly - async_omni_path = "vllm_omni.entrypoints.async_omni" - if async_omni_path in sys.modules: - 
async_omni_module = sys.modules[async_omni_path] - setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) - - # Now mock get_config and other functions - fake_hf_config = mocker.MagicMock() - fake_hf_config.model_type = "qwen2_5_omni" - - def _mock_get_config(model, **kwargs): - return fake_hf_config - - monkeypatch.setattr( - "vllm.transformers_utils.config.get_config", - _mock_get_config, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.get_config", - _mock_get_config, - raising=False, - ) - - # Mock transformers' cached_file to avoid downloading model configs - def _mock_cached_file(path_or_repo_id, *args, **kwargs): - import os - import tempfile - - fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json") - if not os.path.exists(fake_config_file): - with open(fake_config_file, "w") as f: - f.write('{"model_type": "qwen2_5_omni"}') - return fake_config_file - - monkeypatch.setattr( - "transformers.utils.hub.cached_file", - _mock_cached_file, - raising=False, - ) - monkeypatch.setattr( - "transformers.utils.hub.cached_files", - lambda path_or_repo_id, filenames, **kwargs: ( - [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None - ), - raising=False, - ) - - -def test_initialize_stage_configs_called_when_none( - monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config -): - """Test that stage configs are auto-loaded when stage_configs_path is None.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - cfg0 = dict(fake_stage_config) - cfg0["stage_id"] = 0 - cfg1 = dict(fake_stage_config) - cfg1["stage_id"] = 1 - return None, [ - _FakeStageConfig(cfg0), - _FakeStageConfig(cfg1), - ] - - # Remove modules from cache BEFORE setting mocks - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - 
"vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - # Set up mocks - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - # Mock load_and_resolve_stage_configs - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - - # Replace OmniStage - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - # Import the module after mocks are set - import vllm_omni.entrypoints.omni as omni_module - - # Patch the imported class in the module - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - # Verify: auto-loaded stage_configs and stage_list have consistent count - assert isinstance(omni.stage_configs, list) - assert len(omni.stage_configs) == 2 - assert len(omni.stage_list) == 2 - # Verify: each Stage is _FakeStage instance - for st in omni.stage_list: - assert isinstance(st, _FakeStage) - # Verify: queues are attached - for st in omni.stage_list: - assert st._in_q is not None - assert st._out_q is not None - # Verify: all stages are ready - assert len(omni._stages_ready) == 2 - - -def test_generate_raises_on_length_mismatch(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test that generate raises ValueError when sampling_params_list length doesn't match.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, 
[_FakeStageConfig(fake_stage_config)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - with pytest.raises(ValueError): - omni.generate(prompts=["hi"], sampling_params_list=[]) - - -def test_generate_pipeline_and_final_outputs(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test multi-stage generation pipeline with queue polling.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 1 - stage_cfg1["processed_input"] = ["processed-for-stage-1"] - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - "vllm_omni.distributed.omni_connectors.adapter", - 
"vllm_omni.distributed.omni_connectors", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - # Apply connector and adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: manually put results into output queues - # Note: We put results before calling generate, which simulates worker processes - # that have already completed. The polling loop will collect them in stage order. 
- # Stage 0 output (will be collected first) - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - # Stage 1 output (will be collected after stage 0 forwards to it) - # Note: In real flow, stage 1 result would appear after stage 0 forwards, - # but for testing we pre-populate it. The polling loop processes stages - # in order, so stage 0 result will be collected first, then forwarded, - # then stage 1 result will be collected. - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1, "text": "s1"}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - - sampling_params_list = [ - OmniDiffusionSamplingParams(num_inference_steps=1), - OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10), - ] - prompts = ["hi"] - outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) - - # Both stages have final_output=True, so should aggregate two OmniRequestOutput - assert len(outputs) == 2 - # Verify stage outputs are set - assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] - assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] - # Verify stage 0 input queue received the task - assert not omni.stage_list[0]._in_q.empty() - # Verify stage 1 received forwarded task (process_engine_inputs was called) - assert omni.stage_list[1].process_engine_inputs([], []) is not None - - -def test_generate_pipeline_with_batch_input(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test single-stage generation pipeline with multiple inputs in one batch.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg0["final_output"] = False - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 
1 - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - "vllm_omni.distributed.omni_connectors.adapter", - "vllm_omni.distributed.omni_connectors", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - # Apply connector and adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: manually put results into output queues - # Note: We put results before calling generate, which simulates worker processes - # 
that have already completed. The polling loop will collect them in stage order. - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - - outputs = omni.generate( - prompts=[ - { - "prompt": "hi", - "negative_prompt": "hi", - "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]}, - }, - { - "prompt": "hi", - "negative_prompt": "hi", - "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]}, - }, - ], - sampling_params_list=[ - OmniDiffusionSamplingParams(num_inference_steps=1), - OmniDiffusionSamplingParams(num_inference_steps=1), - ], - ) - - assert len(outputs) == 2 - - -def test_generate_no_final_output_returns_empty( - monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config -): - """Test that generate returns empty list when all stages have final_output=False.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg0["final_output"] = False - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 1 - stage_cfg1["final_output"] = False - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), 
_FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - "vllm_omni.distributed.omni_connectors.adapter", - "vllm_omni.distributed.omni_connectors", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - # Apply connector and adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: put results into output queues - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 
1}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - - outputs = omni.generate( - prompts=["p"], - sampling_params_list=[ - OmniDiffusionSamplingParams(num_inference_steps=1), - OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10), - ], - ) - assert outputs == [] - - -def test_generate_sampling_params_none_use_default( - monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config -): - """Test that generate uses default sampling params when sampling_params_list is None.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg0["final_output"] = False - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 1 - stage_cfg1["final_output"] = False - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - "vllm_omni.distributed.omni_connectors.adapter", - "vllm_omni.distributed.omni_connectors", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - # Apply connector and adapter 
mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: put results into output queues - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - # Use the default sampling params - omni.generate(prompts=["p"], sampling_params_list=None) - - -def test_wait_for_stages_ready_timeout(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test that _wait_for_stages_ready handles timeout correctly.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(fake_stage_config)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - 
"vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - - # Create a stage that doesn't send stage_ready message - class _FakeStageNoReady(_FakeStage): - def init_stage_worker(self, *args, **kwargs): - # Don't send stage_ready message - self._proc = self._mocker.MagicMock() - self._proc.start = self._mocker.MagicMock() - self._proc.join = self._mocker.MagicMock() - self._proc.is_alive = self._mocker.MagicMock(return_value=False) - self._proc.terminate = self._mocker.MagicMock() - - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStageNoReady(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - - from vllm_omni.entrypoints.omni import Omni - - # Use very short timeout - omni = Omni(model=MODEL, init_timeout=0.01) - # Verify that no stages are ready - assert len(omni._stages_ready) == 0 - - -def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test that generate handles error messages from stages correctly.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(fake_stage_config)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - 
"vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Put error message in output queue - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "error": "test error", - } - ) - # Also put a valid result after error to allow the loop to complete - # (error handling continues the loop, so we need a valid result to finish) - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0, "text": "result"}], - "metrics": _FakeStageRequestStats(num_tokens_out=1, stage_gen_time_ms=10.0), - } - ) - - # Generate should handle error gracefully (log but continue) - sampling_params_list = [OmniDiffusionSamplingParams(num_inference_steps=1)] - outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) - # Should return final output (error was logged but didn't stop processing) - assert isinstance(outputs, list) - # Since final_output=True, should have one output - assert len(outputs) == 1 - - -def test_close_sends_shutdown_signal(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test 
that close() sends shutdown signal to all input queues.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(fake_stage_config)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Call close - omni.close() - - # Verify shutdown signal (None) was sent to input queue - # Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe) - try: - shutdown_signal = omni.stage_list[0]._in_q.get_nowait() - assert shutdown_signal == SHUTDOWN_TASK - except Empty: - # If queue was already empty or only had stage_ready, that's also acceptable - # The important thing is that close() was called without error - pass - - # Verify stop_stage_worker was called (process should be set) - assert omni.stage_list[0]._proc is not None - - -# Tests below are for diffusion dtype normalization fix from the following: -# https://github.com/vllm-project/vllm-omni/pull/1391 -# 
In the future we should ensure dtypes are parsed in less hacky way and make -# these tests more atomic. -@pytest.mark.parametrize("dtype", ["float16", torch.float16]) -def test_dtype_normalization_valid_types( - monkeypatch, dtype: str | torch.dtype, mocker: MockerFixture, fake_stage_config -): - """Ensure Diffusion Config builder coerces valid types correctly.""" - - def _fake_loader(model: str, base_engine_args=None): - # Return not stage configs to fall back to the diffusion cfg builder - return [] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - - from vllm_omni.entrypoints.omni import Omni - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) - - omni = Omni(model="any", init_timeout=1, dtype=dtype) - - # Dtype parsing being checked is on the diffusion path - assert len(omni.stage_configs) == 1 - assert omni.stage_configs[0]["stage_type"] == "diffusion" - - # Regardless of whether a str / dtype is passed, it resolves correctly - engine_args = omni.stage_configs[0]["engine_args"] - assert "dtype" in engine_args - assert isinstance(engine_args["dtype"], str) - assert engine_args["dtype"] == "float16" - - -def test_dtype_normalization_invalid_types(monkeypatch, mocker: MockerFixture, fake_stage_config): - """Ensure Diffusion Config builder correctly handles bad dtype overrides.""" - - def _fake_loader(model: str, base_engine_args=None): - # Return not stage configs to fall back to the diffusion cfg builder - return [] - - class NotATorchDtype: - pass - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - 
"vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - - from vllm_omni.entrypoints.omni import Omni - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_stage_configs_from_model", - _fake_loader, - raising=False, - ) - - # Raise TypeError if we get an unrecognized type - with pytest.raises(TypeError): - Omni(model="any", init_timeout=1, dtype=NotATorchDtype) diff --git a/tests/entrypoints/test_omni_input_preprocessor.py b/tests/entrypoints/test_omni_input_preprocessor.py deleted file mode 100644 index 422154bf96..0000000000 --- a/tests/entrypoints/test_omni_input_preprocessor.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest - -from vllm_omni.inputs.preprocess import OmniInputPreprocessor - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def _make_preprocessor(monkeypatch): - preprocessor = object.__new__(OmniInputPreprocessor) - monkeypatch.setattr(preprocessor, "_truncate_inputs", lambda tokens, tokenization_kwargs=None: tokens) - monkeypatch.setattr( - preprocessor, - "_process_multimodal", - lambda *args, **kwargs: {"prompt_token_ids": [1, 2, 3]}, - ) - monkeypatch.setattr(preprocessor, "_tokenize_prompt", lambda prompt_text, tokenization_kwargs=None: [9, 8, 7]) - return preprocessor - - -def test_process_tokens_keeps_additional_information(monkeypatch): - preprocessor = _make_preprocessor(monkeypatch) - parsed = { - "prompt_token_ids": [1, 2, 3], - "prompt_embeds": "embeds", - "additional_information": {"task": ["tts"], "lang": ["auto"]}, - } - - inputs = OmniInputPreprocessor._process_tokens(preprocessor, parsed) - - assert inputs["prompt_token_ids"] == [1, 2, 3] - assert inputs["prompt_embeds"] == "embeds" - assert inputs["additional_information"] == {"task": ["tts"], "lang": ["auto"]} - - -def 
test_process_text_keeps_additional_information(monkeypatch): - preprocessor = _make_preprocessor(monkeypatch) - parsed = { - "prompt": "hello", - "prompt_embeds": "embeds", - "additional_information": {"speaker": ["alice"]}, - } - - inputs = OmniInputPreprocessor._process_text(preprocessor, parsed) - - assert inputs["prompt_token_ids"] == [9, 8, 7] - assert inputs["prompt_embeds"] == "embeds" - assert inputs["additional_information"] == {"speaker": ["alice"]} - - -def test_process_text_multimodal_skips_empty_payloads(monkeypatch): - preprocessor = _make_preprocessor(monkeypatch) - parsed = { - "prompt": "hello", - "multi_modal_data": {"image": "fake"}, - "prompt_embeds": None, - "additional_information": None, - } - - inputs = OmniInputPreprocessor._process_text(preprocessor, parsed) - - assert inputs["prompt_token_ids"] == [1, 2, 3] - assert "prompt_embeds" not in inputs - assert "additional_information" not in inputs diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py deleted file mode 100644 index b11b14c893..0000000000 --- a/tests/entrypoints/test_omni_llm.py +++ /dev/null @@ -1,1266 +0,0 @@ -import inspect -import uuid -import warnings -from queue import Empty, Queue -from typing import Any - -import pytest -from pytest_mock import MockerFixture -from vllm import SamplingParams - -from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - -# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies. 
-warnings.filterwarnings( - "ignore", - message=r"builtin type SwigPy.*has no __module__ attribute", - category=DeprecationWarning, -) - -MODEL = "riverclouds/qwen_image_random" - - -class _FakeEngineArgs(dict): - """Fake engine args that can be used both as object attributes and as **kwargs.""" - - def __init__(self, args_dict: dict[str, Any]): - super().__init__(args_dict) - # Add required attributes if not present - if "model_stage" not in self: - self["model_stage"] = None - if "engine_output_type" not in self: - self["engine_output_type"] = None - # Also set as attributes for object-style access - for key, value in self.items(): - setattr(self, key, value) - - -class _FakeStageConfig: - """Fake stage config object that mimics the real stage config structure.""" - - def __init__(self, config_dict: dict[str, Any]): - # engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs) - engine_args_dict = config_dict.get("engine_args", {}) - self.engine_args = _FakeEngineArgs(engine_args_dict) - self.final_output = config_dict.get("final_output", False) - self.final_output_type = config_dict.get("final_output_type", None) - self.stage_id = config_dict.get("stage_id", 0) - # Store original dict for reference - self._config_dict = config_dict - - -class _FakeQueue: - """Fake queue using standard library Queue to replace mp.Queue.""" - - def __init__(self, maxsize=0): - self._queue = Queue(maxsize=maxsize) - - def put(self, item): - self._queue.put(item) - - def put_nowait(self, item): - self._queue.put_nowait(item) - - def get(self, timeout=None): - if timeout is None: - return self._queue.get() - return self._queue.get(timeout=timeout) - - def get_nowait(self): - return self._queue.get_nowait() - - def empty(self): - return self._queue.empty() - - -class _FakeZmqQueue: - """Fake ZmqQueue that wraps _FakeQueue and matches ZmqQueue interface.""" - - def __init__( - self, - ctx=None, - socket_type=None, - *, - bind: str | None = None, - connect: str 
| None = None, - recv_timeout_ms: int | None = None, - send_timeout_ms: int | None = None, - ): - """Initialize fake ZMQ queue with same signature as real ZmqQueue.""" - self._queue = _FakeQueue(maxsize=0) - # Determine endpoint from bind or connect - path = bind if bind is not None else connect - self.endpoint = path or f"fake://zmq-endpoint-{id(self)}" - self._recv_timeout_ms = recv_timeout_ms - self._send_timeout_ms = send_timeout_ms - - def put(self, obj: Any) -> None: - """Send an object to the queue.""" - self._queue.put(obj) - - def put_nowait(self, obj: Any) -> None: - """Send an object to the queue without blocking.""" - self._queue.put_nowait(obj) - - def get(self, timeout: float | None = None) -> Any: - """Receive an object from the queue with optional timeout in seconds.""" - return self._queue.get(timeout=timeout) - - def get_nowait(self) -> Any: - """Receive an object from the queue without blocking.""" - return self._queue.get_nowait() - - def empty(self) -> bool: - """Check if the queue is empty without blocking.""" - return self._queue.empty() - - def close(self) -> None: - """Close the queue.""" - pass - - -class _FakeStage: - """Lightweight Stage stub for multi-process pipeline version with queue support.""" - - def __init__(self, mocker: MockerFixture, config, stage_init_timeout: int = 300): - # Handle both dict and object configs - if isinstance(config, dict): - config = _FakeStageConfig(config) - self.config = config - self.stage_config = config - self.engine = None - self.engine_outputs = None - # Set attributes that OmniStage expects - self.stage_id = getattr(config, "stage_id", 0) - self.engine_args = config.engine_args - self.model_stage = getattr(config.engine_args, "model_stage", None) - self.stage_type = "llm" - # set default sampling params - self.default_sampling_params = SamplingParams(temperature=1.0) - # Allow configuring final_output and final_output_type - self.final_output = config.final_output if hasattr(config, "final_output") 
else False - self.final_output_type = getattr(config, "final_output_type", None) - # Configurable processing logic, default returns placeholder - processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"]) - self._processed_input = processed_input - # Queue references (set by attach_queues) - self._in_q = None - self._out_q = None - self._proc = None # Mock process reference - self._stage_init_timeout = max(0, int(stage_init_timeout)) - # NOTE: mocker fixture object **must be provided** - self._mocker = mocker - - def attach_queues(self, in_q, out_q): - """Attach input and output queues.""" - self._in_q = in_q - self._out_q = out_q - - def init_stage_worker( - self, - model: str, - *, - is_async: bool = False, - shm_threshold_bytes: int = 65536, - ctx=None, - batch_timeout: int = 10, - **kwargs, - ): - """Mock init_stage_worker: don't start real process, just send stage_ready message.""" - # Create a mock process object - self._proc = self._mocker.MagicMock() - self._proc.start = self._mocker.MagicMock() - self._proc.join = self._mocker.MagicMock() - self._proc.is_alive = self._mocker.MagicMock(return_value=False) - self._proc.terminate = self._mocker.MagicMock() - # Send stage_ready message to output queue - if self._out_q is not None: - try: - self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id}) - except Exception: - pass - - def stop_stage_worker(self): - """Mock stop_stage_worker: clean up queue references.""" - if self._in_q is not None: - try: - self._in_q.put_nowait(SHUTDOWN_TASK) - except Exception: - pass - - def submit(self, payload: dict[str, Any]): - """Submit task to input queue.""" - if self._in_q is not None: - self._in_q.put(payload) - - def try_collect(self) -> Any: - """Non-blocking collect from output queue.""" - if self._out_q is None: - return None - try: - return self._out_q.get_nowait() - except Empty: - return None - - def set_engine_outputs(self, outputs): - """Set engine outputs for the 
stage.""" - self.engine_outputs = outputs - - def process_engine_inputs(self, stage_list, prompts): - """Process engine inputs: return preset processed result.""" - return self._processed_input - - -class _FakeEngine: - """Lightweight Engine stub: provides generate iterator output.""" - - def __init__(self, outputs: list[Any]): - self._outputs = outputs - - def generate(self, prompts, sampling_params): - # Record the most recent prompts for outer assertions - self._last_prompts = prompts - # Simplified: return preset list at once, ensuring iterability - yield from self._outputs - - -@pytest.fixture -def fake_stage_config(): - return { - # Don't include 'model' in engine_args since it's passed separately - "engine_args": {}, - "final_output": True, - "final_output_type": "text", - # Second stage will use processed_input to verify the chain - "processed_input": ["processed-by-stage"], - } - - -def _setup_engine_mocks(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture): - """Helper function to set up common engine mocks.""" - fake_engine = mocker.MagicMock() - # Add necessary attributes to fake_engine - fake_engine.tokenizer = mocker.MagicMock() - fake_engine.log_stats = False - fake_engine.vllm_config = mocker.MagicMock() - fake_engine.vllm_config.model_config = mocker.MagicMock() - fake_engine.vllm_config.model_config.io_processor_plugin = None - fake_engine.get_supported_tasks = mocker.MagicMock(return_value=[]) - fake_engine.model_config = mocker.MagicMock() - fake_engine.model_config.io_processor_plugin = None - # Add registry with resolve_model_cls method - fake_registry = mocker.MagicMock() - fake_registry.resolve_model_cls = mocker.MagicMock(return_value=(mocker.MagicMock(), "test_arch")) - fake_engine.model_config.registry = fake_registry - fake_engine.vllm_config.model_config.registry = fake_registry - - monkeypatch.setattr( - "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args", - lambda **kw: fake_engine, - raising=False, - ) - - # Mock 
model_config.registry.resolve_model_cls to return a tuple - # Use a real class instead of MagicMock to avoid inspect.getsource issues - class FakeModelClass: - pass - - monkeypatch.setattr( - "vllm.model_executor.model_loader.utils.get_model_architecture", - lambda model_config: (FakeModelClass, "test_arch"), - raising=False, - ) - - monkeypatch.setattr( - "vllm.model_executor.model_loader.utils._get_model_architecture", - lambda model_config: (FakeModelClass, "test_arch"), - raising=False, - ) - - # Mock try_create_mm_pooling_model_cls to return the class as-is - monkeypatch.setattr( - "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls", - lambda model_cls: model_cls, - raising=False, - ) - - # Mock _enable_processor_cache to return False - monkeypatch.setattr( - "vllm.multimodal.cache._enable_processor_cache", - lambda model_config, mm_registry: False, - raising=False, - ) - - # Mock get_io_processor to return None - monkeypatch.setattr( - "vllm.plugins.io_processors.get_io_processor", - lambda vllm_config, io_processor_plugin: None, - raising=False, - ) - - -def _setup_multiprocessing_mocks(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture): - """Helper function to set up multiprocessing mocks.""" - import multiprocessing as mp - - # Mock Process - fake_process_class = mocker.MagicMock() - fake_process_instance = mocker.MagicMock() - fake_process_instance.start = mocker.MagicMock() - fake_process_instance.join = mocker.MagicMock() - fake_process_instance.is_alive = mocker.MagicMock(return_value=False) - fake_process_instance.terminate = mocker.MagicMock() - fake_process_class.return_value = fake_process_instance - - # Mock get_context to return a context with Queue that returns _FakeQueue - fake_ctx = mocker.MagicMock() - fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize) - fake_ctx.Process = fake_process_class - - def _mock_get_context(method): - return fake_ctx - - monkeypatch.setattr(mp, "get_context", _mock_get_context, 
raising=False) - monkeypatch.setattr(mp, "Process", fake_process_class, raising=False) - - # Mock ZmqQueue to use _FakeZmqQueue - monkeypatch.setattr( - "vllm_omni.entrypoints.zmq_utils.ZmqQueue", - _FakeZmqQueue, - raising=False, - ) - # Also mock where ZmqQueue is imported/used - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.ZmqQueue", - _FakeZmqQueue, - raising=False, - ) - - -def _setup_ipc_mocks(monkeypatch: pytest.MonkeyPatch): - """Helper function to set up IPC function mocks.""" - - # Mock _encode: simple serialization - def _fake_encode(obj, threshold, obj_key, shm_key): - return {obj_key: obj} - - # Mock _load: extract object from result - def _fake_load(result, obj_key, shm_key): - return result.get(obj_key) - - # Mock _set: calculate serialization size - def _fake_set(obj): - return str(obj).encode() - - monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False) - monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False) - monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False) - - -def _setup_log_mocks(monkeypatch: pytest.MonkeyPatch): - """Helper function to set up logging and stats mocks.""" - # Mock OrchestratorMetrics to be a simple class that doesn't require file operations - - class _FakeOrchestratorMetrics: - def __init__(self, num_stages, log_stats, wall_start_ts): - self.num_stages = num_stages - self.log_stats = log_stats - self.stage_first_ts = [None] * num_stages - self.stage_last_ts = [None] * num_stages - self.e2e_done = set() - - def on_stage_metrics(self, stage_id, req_id, metrics): - pass - - def on_finalize_request(self, stage_id, req_id, start_ts): - self.e2e_done.add(req_id) - - def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm): - pass - - def build_and_log_summary(self, final_stage_id): - return "Fake summary" - - monkeypatch.setattr( - "vllm_omni.entrypoints.omni.OrchestratorAggregator", - _FakeOrchestratorMetrics, 
- raising=False, - ) - - -def _setup_connector_mocks(monkeypatch, mocker, omni_module=None): - """Helper function to set up connector mocks for stage-to-stage forwarding. - - If omni_module is provided, mocks directly on the module. Otherwise, uses string path. - """ - - # Mock initialize_orchestrator_connectors to return fake connectors - def _fake_initialize_orchestrator_connectors(config_path, worker_backend=None, shm_threshold_bytes=None): - # Create fake connectors for all stage-to-stage edges - # Each connector is just a mock object that will be passed to try_send_via_connector - fake_connectors = {} - # Add connectors for common edges (0->1, 1->2, etc.) - for i in range(10): # Support up to 10 stages - fake_connectors[(str(i), str(i + 1))] = mocker.MagicMock() - return None, fake_connectors - - if omni_module is not None: - # Mock directly on the omni module where it's used (after import) - monkeypatch.setattr(omni_module, "initialize_orchestrator_connectors", _fake_initialize_orchestrator_connectors) - else: - # Mock via string path (before import) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni.initialize_orchestrator_connectors", - _fake_initialize_orchestrator_connectors, - raising=False, - ) - - -def _setup_connector_adapter_mock(monkeypatch, omni_module): - """Helper function to mock try_send_via_connector on the omni module. - - This must be called AFTER importing omni module, to mock the function where it's actually used. 
- """ - - # Mock try_send_via_connector to always succeed - def _fake_try_send_via_connector( - connector, - stage_id, - next_stage_id, - req_id, - next_inputs, - sampling_params, - original_prompt, - next_stage_queue_submit_fn, - metrics, - ): - # Simulate successful send by calling the submit function - task = { - "request_id": req_id, - "engine_inputs": next_inputs, - "sampling_params": sampling_params, - } - next_stage_queue_submit_fn(task) - return True - - # Mock directly on the omni module where it's used - monkeypatch.setattr(omni_module, "try_send_via_connector", _fake_try_send_via_connector) - - -@pytest.fixture(autouse=True) -def mock_get_config(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture): - """Auto-mock get_config and related model loading functions to avoid model path validation.""" - # CRITICAL: Mock tokenizer-related imports FIRST, before any module imports - # This prevents ImportError when async_omni is imported (which happens via omni_stage) - import sys - - fake_tokenizer = mocker.MagicMock() - fake_tokenizer.encode = mocker.MagicMock(return_value=[1, 2, 3]) - fake_tokenizer.decode = mocker.MagicMock(return_value="test") - - # Mock init_tokenizer_from_configs (used in async_omni) - def _mock_init_tokenizer_from_configs(model_config=None, **kwargs): - return fake_tokenizer - - # Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer) - # This works if the module hasn't been imported yet - monkeypatch.setattr( - "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs", - _mock_init_tokenizer_from_configs, - raising=False, - ) - - # Strategy 2: If the module is already in sys.modules, patch it directly - tokenizer_module_path = "vllm.transformers_utils.tokenizer" - if tokenizer_module_path in sys.modules: - tokenizer_module = sys.modules[tokenizer_module_path] - setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) - - # CRITICAL: Mock 
length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni - - # This is because async_omni imports processor.py, which imports this function at module level - # Mock length_from_prompt_token_ids_or_embeds (used in processor.py) - def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None): - # Return a reasonable default length - if prompt_token_ids is not None: - if isinstance(prompt_token_ids, list): - return len(prompt_token_ids) - elif hasattr(prompt_token_ids, "shape"): - return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1 - if prompt_embeds is not None: - if hasattr(prompt_embeds, "shape"): - return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1 - return 10 # Default length - - # Mock in vllm.utils - monkeypatch.setattr( - "vllm.utils.length_from_prompt_token_ids_or_embeds", - _mock_length_from_prompt_token_ids_or_embeds, - raising=False, - ) - # Also mock in processor module if it's imported - monkeypatch.setattr( - "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds", - _mock_length_from_prompt_token_ids_or_embeds, - raising=False, - ) - # If processor module is already imported, patch it directly - processor_module_path = "vllm_omni.engine.input_processor" - if processor_module_path in sys.modules: - processor_module = sys.modules[processor_module_path] - setattr( - processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds - ) - - # Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked - # This prevents ImportError when async_omni imports processor.py - monkeypatch.setattr( - "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", - _mock_init_tokenizer_from_configs, - raising=False, - ) - - # Strategy 4: If async_omni is already imported, patch it directly - async_omni_path = "vllm_omni.entrypoints.async_omni" - if async_omni_path in sys.modules: - 
async_omni_module = sys.modules[async_omni_path] - setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs) - - # Now mock get_config and other functions - fake_hf_config = mocker.MagicMock() - fake_hf_config.model_type = "qwen2_5_omni" - - def _mock_get_config(model, **kwargs): - return fake_hf_config - - monkeypatch.setattr( - "vllm.transformers_utils.config.get_config", - _mock_get_config, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.get_config", - _mock_get_config, - raising=False, - ) - - # Mock transformers' cached_file to avoid downloading model configs - def _mock_cached_file(path_or_repo_id, *args, **kwargs): - import os - import tempfile - - fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json") - if not os.path.exists(fake_config_file): - with open(fake_config_file, "w") as f: - f.write('{"model_type": "qwen2_5_omni"}') - return fake_config_file - - monkeypatch.setattr( - "transformers.utils.hub.cached_file", - _mock_cached_file, - raising=False, - ) - monkeypatch.setattr( - "transformers.utils.hub.cached_files", - lambda path_or_repo_id, filenames, **kwargs: ( - [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None - ), - raising=False, - ) - - -def test_initialize_stage_configs_called_when_none( - monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config -): - """Test that stage configs are auto-loaded when stage_configs_path is None.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - cfg0 = dict(fake_stage_config) - cfg0["stage_id"] = 0 - cfg1 = dict(fake_stage_config) - cfg1["stage_id"] = 1 - return None, [ - _FakeStageConfig(cfg0), - _FakeStageConfig(cfg1), - ] - - # Remove modules from cache BEFORE setting mocks - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - 
"vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - # Set up mocks - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - # Mock load_and_resolve_stage_configs - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - - # Replace OmniStage - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - # Import the module after mocks are set - import vllm_omni.entrypoints.omni as omni_module - - # Patch the imported function and class in the module - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - # Apply connector and adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - # Verify: auto-loaded stage_configs and stage_list have consistent count - assert isinstance(omni.stage_configs, list) - assert len(omni.stage_configs) == 2 - assert len(omni.stage_list) == 2 - # Verify: each Stage is _FakeStage instance - for st in omni.stage_list: - assert isinstance(st, _FakeStage) - # Verify: queues are attached - for st in omni.stage_list: - assert st._in_q is not None - assert st._out_q is not None - # Verify: all stages are ready - assert len(omni._stages_ready) == 2 - - -def test_generate_raises_on_length_mismatch(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test that generate raises ValueError when sampling_params_list length doesn't match.""" - - def _fake_loader( - model: str, - 
stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - cfg0 = dict(fake_stage_config) - cfg0["stage_id"] = 0 - return None, [_FakeStageConfig(cfg0)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - with pytest.raises(ValueError): - omni.generate(prompts=["hi"], sampling_params_list=[]) - - -def test_generate_pipeline_and_final_outputs(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test multi-stage generation pipeline with queue polling.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 1 - stage_cfg1["processed_input"] = ["processed-for-stage-1"] - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - 
"vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - # Apply connector and adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: manually put results into output queues - # Note: We put results before calling generate, which simulates worker processes - # that have already completed. The polling loop will collect them in stage order. 
- # Stage 0 output (will be collected first) - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0, "text": "s0"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - # Stage 1 output (will be collected after stage 0 forwards to it) - # Note: In real flow, stage 1 result would appear after stage 0 forwards, - # but for testing we pre-populate it. The polling loop processes stages - # in order, so stage 0 result will be collected first, then forwarded, - # then stage 1 result will be collected. - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1, "text": "s1"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - - sampling_params_list = [ - SamplingParams(temperature=0.7), - SamplingParams(temperature=0.8), - ] - prompts = ["hi"] - outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list) - - # Both stages have final_output=True, so should aggregate two OmniRequestOutput - assert len(outputs) == 2 - # Verify stage outputs are set - assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}] - assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}] - # Verify stage 0 input queue received the task - assert not omni.stage_list[0]._in_q.empty() - # Verify stage 1 received forwarded task (process_engine_inputs was called) - assert omni.stage_list[1].process_engine_inputs([], []) is not None - - -def test_generate_no_final_output_returns_empty( - monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config -): - """Test that generate returns empty list when all stages have final_output=False.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg0["final_output"] = False - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 1 - stage_cfg1["final_output"] = False - - def _fake_loader( - model: str, - 
stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - # Apply connector and adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: put results into output queues - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - 
) - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - - outputs = omni.generate( - prompts=["p"], - sampling_params_list=[ - SamplingParams(temperature=0.7), - SamplingParams(temperature=0.8), - ], - ) - assert outputs == [] - - -def test_generate_sampling_params_none_use_default( - monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config -): - """Test that generate uses default sampling params when sampling_params_list is None.""" - stage_cfg0 = dict(fake_stage_config) - stage_cfg0["stage_id"] = 0 - stage_cfg0["final_output"] = False - stage_cfg1 = dict(fake_stage_config) - stage_cfg1["stage_id"] = 1 - stage_cfg1["final_output"] = False - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - return None, [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - # Apply connector and 
adapter mocks after importing omni module - _setup_connector_mocks(monkeypatch, mocker, omni_module) - _setup_connector_adapter_mock(monkeypatch, omni_module) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Simulate worker behavior: put results into output queues - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - omni.stage_list[1]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 1}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - # Use the default sampling params - omni.generate(prompts=["p"], sampling_params_list=None) - - -def test_wait_for_stages_ready_timeout(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test that _wait_for_stages_ready handles timeout correctly.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - cfg0 = dict(fake_stage_config) - cfg0["stage_id"] = 0 - return None, [_FakeStageConfig(cfg0)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - 
monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - - # Create a stage that doesn't send stage_ready message - class _FakeStageNoReady(_FakeStage): - def init_stage_worker(self, *args, **kwargs): - # Don't send stage_ready message - self._proc = self._mocker.MagicMock() - self._proc.start = self._mocker.MagicMock() - self._proc.join = self._mocker.MagicMock() - self._proc.is_alive = self._mocker.MagicMock(return_value=False) - self._proc.terminate = self._mocker.MagicMock() - - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStageNoReady(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(mocker, cfg, **kwargs)) - - from vllm_omni.entrypoints.omni import Omni - - # Use very short timeout - omni = Omni(model=MODEL, init_timeout=0.01) - # Verify that no stages are ready - assert len(omni._stages_ready) == 0 - - -def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): - """Test that generate handles error messages from stages correctly.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - cfg0 = dict(fake_stage_config) - cfg0["stage_id"] = 0 - return None, [_FakeStageConfig(cfg0)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - 
_setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - - # Mock uuid.uuid4() to return a predictable value for request ID generation - test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000") - monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid) - monkeypatch.setattr(omni_module, "uuid", uuid) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Generate the expected request ID format: "0_" - expected_request_id = f"0_{test_uuid}" - - # Put error message in output queue - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "error": "test error", - } - ) - # Also put a valid result after error to allow the loop to complete - # (error handling continues the loop, so we need a valid result to finish) - omni.stage_list[0]._out_q.put_nowait( - { - "request_id": expected_request_id, - "engine_outputs": [{"stage": 0, "text": "result"}], - "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0}, - } - ) - - # Generate should handle error gracefully (log but continue) - sampling_params_list = [SamplingParams(temperature=0.7)] - outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list) - # Should return final output (error was logged but didn't stop processing) - assert isinstance(outputs, list) - # Since final_output=True, should have one output - assert len(outputs) == 1 - - -def test_close_sends_shutdown_signal(monkeypatch: pytest.MonkeyPatch, mocker: 
MockerFixture, fake_stage_config): - """Test that close() sends shutdown signal to all input queues.""" - - def _fake_loader( - model: str, - stage_configs_path: str | None = None, - base_engine_args: dict | None = None, - default_stage_cfg_factory=None, - ): - cfg0 = dict(fake_stage_config) - cfg0["stage_id"] = 0 - return None, [_FakeStageConfig(cfg0)] - - import sys - - for module_name in [ - "vllm_omni.entrypoints.utils", - "vllm_omni.entrypoints.omni", - "vllm_omni.entrypoints.omni_stage", - ]: - if module_name in sys.modules: - del sys.modules[module_name] - - _setup_engine_mocks(monkeypatch, mocker) - _setup_multiprocessing_mocks(monkeypatch, mocker) - _setup_ipc_mocks(monkeypatch) - _setup_log_mocks(monkeypatch) - _setup_connector_mocks(monkeypatch, mocker) - - monkeypatch.setattr( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - _fake_loader, - raising=False, - ) - monkeypatch.setattr( - "vllm_omni.entrypoints.omni_stage.OmniStage", - lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), - raising=False, - ) - - import vllm_omni.entrypoints.omni as omni_module - - monkeypatch.setattr(omni_module, "load_and_resolve_stage_configs", _fake_loader) - monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) - - from vllm_omni.entrypoints.omni import Omni - - omni = Omni(model=MODEL, init_timeout=1) - - # Call close - omni.close() - - # Verify shutdown signal (None) was sent to input queue - # Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe) - try: - shutdown_signal = omni.stage_list[0]._in_q.get_nowait() - assert shutdown_signal == SHUTDOWN_TASK - except Empty: - # If queue was already empty or only had stage_ready, that's also acceptable - # The important thing is that close() was called without error - pass - - # Verify stop_stage_worker was called (process should be set) - assert omni.stage_list[0]._proc is not None - - -# 
--------------------------------------------------------------------------- -# Signature compatibility tests — catch upstream API drift early -# --------------------------------------------------------------------------- - -_INHERITANCE_PAIRS: list[tuple[str, str, str, str]] = [ - ( - "vllm.entrypoints.llm", - "LLM", - "vllm_omni.entrypoints.omni_llm", - "OmniLLM", - ), -] - - -def _import_class(module_path: str, class_name: str): - mod = __import__(module_path, fromlist=[class_name]) - return getattr(mod, class_name) - - -@pytest.mark.parametrize( - "up_mod,up_cls,omni_mod,omni_cls", - _INHERITANCE_PAIRS, - ids=[f"{pair[3]}({pair[1]})" for pair in _INHERITANCE_PAIRS], -) -def test_overridden_method_signatures_compatible(up_mod: str, up_cls: str, omni_mod: str, omni_cls: str): - """All params accepted by a base-class method must also be accepted by - the overriding method in the vllm-omni subclass. This prevents - TypeError at runtime when upstream callers pass new arguments.""" - BaseCls = _import_class(up_mod, up_cls) - OmniCls = _import_class(omni_mod, omni_cls) - - failures: list[str] = [] - for name in dir(OmniCls): - if name.startswith("__") and name != "__init__": - continue - omni_method = getattr(OmniCls, name, None) - base_method = getattr(BaseCls, name, None) - if not (callable(omni_method) and callable(base_method)): - continue - if omni_method is base_method: - continue - try: - base_sig = inspect.signature(base_method) - omni_sig = inspect.signature(omni_method) - except (ValueError, TypeError): - continue - - omni_has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in omni_sig.parameters.values()) - - base_params = base_sig.parameters - omni_params = set(omni_sig.parameters.keys()) - - missing = [] - for pname, param in base_params.items(): - if pname in omni_params: - continue - if omni_has_var_keyword and param.kind in ( - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): - continue - 
missing.append(pname) - - if missing: - failures.append(f"{omni_cls}.{name}() missing params {sorted(missing)}; base={base_sig}, omni={omni_sig}") - - assert not failures, f"Signature mismatches found ({len(failures)}):\n" + "\n".join(f" - {f}" for f in failures) diff --git a/tests/entrypoints/test_omni_stage_diffusion_config.py b/tests/entrypoints/test_omni_stage_diffusion_config.py deleted file mode 100644 index f464c55fd6..0000000000 --- a/tests/entrypoints/test_omni_stage_diffusion_config.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm_omni.entrypoints.omni_stage import _build_od_config - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def test_build_od_config_includes_diffusion_fields(): - engine_args = { - "cache_backend": "cache_dit", - "cache_config": {"Fn_compute_blocks": 2}, - "vae_use_slicing": True, - } - od_config = _build_od_config(engine_args, model="dummy-model") - - assert od_config["model"] == "dummy-model" - assert od_config["cache_backend"] == "cache_dit" - assert od_config["cache_config"]["Fn_compute_blocks"] == 2 - assert od_config["vae_use_slicing"] is True - - -def test_build_od_config_respects_explicit_config(): - engine_args = { - "od_config": {"cache_backend": "tea_cache"}, - "cache_backend": "cache_dit", - } - od_config = _build_od_config(engine_args, model="dummy-model") - assert od_config == {"cache_backend": "tea_cache"} diff --git a/tests/perf/scripts/run_benchmark.py b/tests/perf/scripts/run_benchmark.py index ed2fab83e3..704d47f9bb 100644 --- a/tests/perf/scripts/run_benchmark.py +++ b/tests/perf/scripts/run_benchmark.py @@ -50,11 +50,14 @@ def create_unique_server_params(configs: list[dict[str, Any]]) -> list[tuple[str for config in configs: test_name = config["test_name"] model = config["server_params"]["model"] - stage_config_name = config["server_params"]["stage_config_name"] - 
stage_config_path = str(Path(__file__).parent.parent / "stage_configs" / stage_config_name) - delete = config["server_params"].get("delete", None) - update = config["server_params"].get("update", None) - stage_config_path = modify_stage(stage_config_path, update, delete) + stage_config_name = config["server_params"].get("stage_config_name") + if stage_config_name: + stage_config_path = str(Path(__file__).parent.parent / "stage_configs" / stage_config_name) + delete = config["server_params"].get("delete", None) + update = config["server_params"].get("update", None) + stage_config_path = modify_stage(stage_config_path, update, delete) + else: + stage_config_path = None server_param = (test_name, model, stage_config_path) if server_param not in seen: @@ -98,7 +101,10 @@ def omni_server(request): print(f"Starting OmniServer with test: {test_name}, model: {model}") - with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server: + server_args = ["--stage-init-timeout", "120"] + if stage_config_path: + server_args = ["--stage-configs-path", stage_config_path] + server_args + with OmniServer(model, server_args) as server: server.test_name = test_name print("OmniServer started successfully") yield server diff --git a/tests/perf/stage_configs/qwen3_omni.yaml b/tests/perf/stage_configs/qwen3_omni.yaml index 01ca7d9fbe..7b347c7f1b 100644 --- a/tests/perf/stage_configs/qwen3_omni.yaml +++ b/tests/perf/stage_configs/qwen3_omni.yaml @@ -7,7 +7,7 @@ async_chunk: false stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0" max_batch_size: 64 @@ -38,7 +38,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 64 @@ -69,7 +69,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 - 
stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 64 diff --git a/tests/perf/stage_configs/qwen3_tts.yaml b/tests/perf/stage_configs/qwen3_tts.yaml index 4ba4e6e83e..5de7f13f12 100644 --- a/tests/perf/stage_configs/qwen3_tts.yaml +++ b/tests/perf/stage_configs/qwen3_tts.yaml @@ -7,6 +7,7 @@ async_chunk: true stage_args: - stage_id: 0 stage_type: llm + is_comprehension: true runtime: devices: "0" max_batch_size: 4 diff --git a/tests/perf/tests/test.json b/tests/perf/tests/test.json index 65ef6588d9..dc76d1dc4e 100644 --- a/tests/perf/tests/test.json +++ b/tests/perf/tests/test.json @@ -152,8 +152,7 @@ { "test_name": "test_qwen3_tts", "server_params": { - "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", - "stage_config_name": "qwen3_tts.yaml" + "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" }, "benchmark_params": [ { diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index a6c0f3dd3b..1c08a1543d 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -16,7 +16,6 @@ # alternatives like msgpack or pydantic that are already in use in vLLM. Only # add to this list if absolutely necessary and after careful security review. 
ALLOWED_FILES = { - "vllm_omni/entrypoints/omni_llm.py", "tests/e2e/offline_inference/utils.py", "tests/utils.py", "vllm_omni/diffusion/distributed/group_coordinator.py", diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py index c9d7fcf2c1..b093272d2f 100644 --- a/vllm_omni/__init__.py +++ b/vllm_omni/__init__.py @@ -24,10 +24,7 @@ from vllm_omni.transformers_utils import configs as _configs # noqa: F401, E402 from .config import OmniModelConfig -from .entrypoints.async_omni import AsyncOmni - -# Main entry points -from .entrypoints.omni import Omni +from .entrypoints import AsyncOmni, Omni from .version import __version__, __version_tuple__ # isort:skip diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index e457c3d999..f4af0f49b6 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -49,10 +49,11 @@ class StageType(str, Enum): @dataclass class StageConfig: - """Per-stage configuration — pipeline-structure fields only. + """Per-stage configuration from pipeline YAML. - Engine params (gpu_memory_utilization, tp_size, etc.) come from CLI, - NOT from this class. + Topology fields (stage_id, input_sources, etc.) define the DAG. + Engine and runtime defaults come from the YAML; CLI overrides take + precedence via ``runtime_overrides``. """ # Identity @@ -71,6 +72,14 @@ class StageConfig: hf_config_name: str | None = None is_comprehension: bool = False + # Per-stage engine args from pipeline YAML (defaults) + yaml_engine_args: dict[str, Any] = field(default_factory=dict) + # Per-stage runtime config from pipeline YAML (devices, max_batch_size) + yaml_runtime: dict[str, Any] = field(default_factory=dict) + # Pass-through fields from pipeline YAML (default_sampling_params, + # output_connectors, input_connectors, tts_args, etc.) 
+ yaml_extras: dict[str, Any] = field(default_factory=dict) + # Runtime overrides (populated from CLI, not from pipeline YAML) runtime_overrides: dict[str, Any] = field(default_factory=dict) @@ -80,11 +89,11 @@ def to_omegaconf(self) -> Any: Returns: OmegaConf DictConfig with stage configuration in legacy format. """ - # Build engine_args dict with required fields - engine_args: dict[str, Any] = { - "model_stage": self.model_stage, - } + # Start with YAML engine_args defaults + engine_args: dict[str, Any] = dict(self.yaml_engine_args) + # Overlay topology-level fields + engine_args["model_stage"] = self.model_stage if self.worker_type: engine_args["worker_type"] = self.worker_type if self.scheduler_cls: @@ -92,19 +101,22 @@ def to_omegaconf(self) -> Any: if self.hf_config_name: engine_args["hf_config_name"] = self.hf_config_name - # Apply runtime overrides (CLI args) + # CLI overrides take precedence over YAML defaults for key, value in self.runtime_overrides.items(): if key not in ("devices", "max_batch_size"): engine_args[key] = value - # Build runtime config - runtime: dict[str, Any] = { - "process": True, - "max_batch_size": self.runtime_overrides.get("max_batch_size", 1), - } + # Build runtime config from YAML defaults + CLI overrides + runtime: dict[str, Any] = dict(self.yaml_runtime) + runtime.setdefault("process", True) + runtime.setdefault("max_batch_size", self.runtime_overrides.get("max_batch_size", 1)) if "devices" in self.runtime_overrides: runtime["devices"] = self.runtime_overrides["devices"] + # Inject max_num_seqs from max_batch_size (legacy compat) + max_batch_size = int(runtime.get("max_batch_size", 1) or 1) + engine_args.setdefault("max_num_seqs", max_batch_size) + # Build full config dict config_dict: dict[str, Any] = { "stage_id": self.stage_id, @@ -120,6 +132,10 @@ def to_omegaconf(self) -> Any: if self.custom_process_input_func: config_dict["custom_process_input_func"] = self.custom_process_input_func + # Pass through extra YAML fields 
(default_sampling_params, + # output_connectors, input_connectors, tts_args, etc.) + config_dict.update(self.yaml_extras) + return create_config(config_dict) @@ -249,6 +265,15 @@ def create_from_model( if errors: logger.warning(f"Pipeline validation warnings for {model}: {errors}") + # Inject pipeline-wide async_chunk into ALL stages' engine_args. + # The legacy loader (load_stage_configs_from_yaml) sets async_chunk + # on every stage so that build_engine_args_dict() can inject the + # stage_connector_spec. AsyncOmniEngine.__init__ also reads it + # from stage_configs[0].engine_args.async_chunk. + if pipeline.async_chunk: + for stage in pipeline.stages: + stage.yaml_engine_args.setdefault("async_chunk", True) + # Apply CLI overrides result: list[StageConfig] = [] for stage in pipeline.stages: @@ -348,6 +373,25 @@ def _load_pipeline(cls, model: str, trust_remote_code: bool = True) -> ModelPipe return cls._parse_pipeline_yaml(pipeline_path, model_type) + # Keys consumed as explicit StageConfig fields — everything else is + # passed through via yaml_extras. + _KNOWN_STAGE_KEYS: set[str] = { + "stage_id", + "model_stage", + "stage_type", + "input_sources", + "engine_input_source", + "custom_process_input_func", + "final_output", + "final_output_type", + "worker_type", + "scheduler_cls", + "hf_config_name", + "is_comprehension", + "engine_args", + "runtime", + } + @classmethod def _parse_pipeline_yaml(cls, path: Path, model_type: str) -> ModelPipeline: """Parse a pipeline YAML file. 
@@ -375,27 +419,65 @@ def _parse_pipeline_yaml(cls, path: Path, model_type: str) -> ModelPipeline: input_sources = [] input_sources = list(input_sources) + # Extract per-stage engine_args and runtime dicts + raw_ea = stage_data.get("engine_args", None) + yaml_engine_args = to_dict(raw_ea) if raw_ea is not None else {} + raw_rt = stage_data.get("runtime", None) + yaml_runtime = to_dict(raw_rt) if raw_rt is not None else {} + + # Topology-level fields that also live inside engine_args in legacy + # YAMLs (worker_type, scheduler_cls, etc.) — read from both places. + worker_type = stage_data.get("worker_type", None) or yaml_engine_args.pop("worker_type", None) + scheduler_cls = stage_data.get("scheduler_cls", None) or yaml_engine_args.pop("scheduler_cls", None) + hf_config_name = stage_data.get("hf_config_name", None) or yaml_engine_args.pop("hf_config_name", None) + model_stage = getattr(stage_data, "model_stage", None) or yaml_engine_args.pop("model_stage", None) + + # Collect pass-through fields (default_sampling_params, + # output_connectors, input_connectors, tts_args, etc.) 
+ yaml_extras: dict[str, Any] = {} + for key in stage_data: + if key not in cls._KNOWN_STAGE_KEYS: + val = stage_data[key] + try: + yaml_extras[key] = to_dict(val) + except ValueError: + yaml_extras[key] = val + stage = StageConfig( stage_id=stage_data.stage_id, - model_stage=stage_data.model_stage, + model_stage=model_stage or "", stage_type=stage_type, input_sources=input_sources, custom_process_input_func=stage_data.get("custom_process_input_func", None), final_output=stage_data.get("final_output", False), final_output_type=stage_data.get("final_output_type", None), - worker_type=stage_data.get("worker_type", None), - scheduler_cls=stage_data.get("scheduler_cls", None), - hf_config_name=stage_data.get("hf_config_name", None), + worker_type=worker_type, + scheduler_cls=scheduler_cls, + hf_config_name=hf_config_name, is_comprehension=stage_data.get("is_comprehension", False), + yaml_engine_args=yaml_engine_args, + yaml_runtime=yaml_runtime, + yaml_extras=yaml_extras, ) stages.append(stage) # Get pipeline-wide flags async_chunk = config_data.get("async_chunk", False) - # Get optional connector config - connectors = to_dict(config_data.connectors) if hasattr(config_data, "connectors") else None - edges = to_dict(config_data.edges) if hasattr(config_data, "edges") else None + # Get optional connector config — check both top-level and nested + # under ``runtime`` (legacy stage_configs format). 
+ connectors = None + edges = None + if hasattr(config_data, "connectors"): + connectors = to_dict(config_data.connectors) + if hasattr(config_data, "edges"): + edges = to_dict(config_data.edges) + if hasattr(config_data, "runtime") and config_data.runtime is not None: + top_runtime = config_data.runtime + if connectors is None and hasattr(top_runtime, "connectors"): + connectors = to_dict(top_runtime.connectors) + if edges is None and hasattr(top_runtime, "edges"): + edges = to_dict(top_runtime.edges) return ModelPipeline( model_type=getattr(config_data, "model_type", model_type), @@ -421,9 +503,21 @@ def _auto_detect_model_type(cls, model: str, trust_remote_code: bool = True) -> hf_config = get_config(model, trust_remote_code=trust_remote_code) return hf_config.model_type, hf_config + except Exception: + pass + + # Fallback: read config.json directly for custom model types that + # are not registered with transformers (e.g. qwen3_tts). + try: + from vllm.transformers_utils.config import get_hf_file_to_dict + + config_dict = get_hf_file_to_dict("config.json", model, revision=None) + if config_dict and "model_type" in config_dict: + return config_dict["model_type"], None except Exception as e: logger.debug(f"Failed to auto-detect model type for {model}: {e}") - return None, None + + return None, None # Keys that should never be forwarded as engine overrides (internal / # orchestrator-only knobs, complex objects, etc.). diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py new file mode 100644 index 0000000000..85cfd1be4c --- /dev/null +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -0,0 +1,146 @@ +"""Stage Diffusion Client for vLLM-Omni multi-stage runtime. + +Wraps AsyncOmniDiffusion to expose the same interface the Orchestrator +expects from any stage client. 
+""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger + +from vllm_omni.engine.stage_init_utils import StageMetadata +from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion +from vllm_omni.outputs import OmniRequestOutput + +if TYPE_CHECKING: + from vllm_omni.diffusion.data import OmniDiffusionConfig + from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType + +logger = init_logger(__name__) + + +class StageDiffusionClient: + """Wraps AsyncOmniDiffusion for use inside the Orchestrator. + + Exposes the same attributes and async methods the Orchestrator + uses on StageEngineCoreClient, but routes execution through + DiffusionEngine instead of vLLM EngineCore. + """ + + stage_type: str = "diffusion" + + def __init__( + self, + model: str, + od_config: OmniDiffusionConfig, + metadata: StageMetadata, + ) -> None: + self.stage_id = metadata.stage_id + self.final_output = metadata.final_output + self.final_output_type = metadata.final_output_type + self.default_sampling_params = metadata.default_sampling_params + self.custom_process_input_func = metadata.custom_process_input_func + self.engine_input_source = metadata.engine_input_source + + self._engine = AsyncOmniDiffusion(model=model, od_config=od_config) + self._output_queue: asyncio.Queue[OmniRequestOutput] = asyncio.Queue() + self._tasks: dict[str, asyncio.Task] = {} + + logger.info("[StageDiffusionClient] Stage-%s initialized", self.stage_id) + + async def add_request_async( + self, + request_id: str, + prompt: OmniPromptType, + sampling_params: OmniDiffusionSamplingParams, + ) -> None: + task = asyncio.create_task( + self._run(request_id, prompt, sampling_params), + name=f"diffusion-{request_id}", + ) + self._tasks[request_id] = task + + async def _run( + self, + request_id: str, + prompt: OmniPromptType, + sampling_params: OmniDiffusionSamplingParams, + ) -> None: + try: + result = await 
self._engine.generate(prompt, sampling_params, request_id) + await self._output_queue.put(result) + except Exception as e: + logger.exception( + "[StageDiffusionClient] Stage-%s req=%s failed: %s", + self.stage_id, + request_id, + e, + ) + finally: + self._tasks.pop(request_id, None) + + def get_diffusion_output_async(self) -> OmniRequestOutput | None: + try: + return self._output_queue.get_nowait() + except asyncio.QueueEmpty: + return None + + async def abort_requests_async(self, request_ids: list[str]) -> None: + for rid in request_ids: + task = self._tasks.pop(rid, None) + if task: + task.cancel() + + async def collective_rpc_async( + self, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + """Best-effort control RPC shim for diffusion stages. + + TODO(AsyncOmni): add dedicated wrappers on AsyncOmniDiffusion for the + remaining control APIs instead of reaching into its underlying engine. + """ + kwargs = kwargs or {} + + if method in {"add_lora", "remove_lora", "list_loras", "pin_lora", "start_profile", "stop_profile"}: + target = getattr(self._engine, method, None) + if target is None: + return { + "supported": False, + "todo": True, + "reason": f"AsyncOmniDiffusion.{method} is not implemented", + } + result = target(*args, **kwargs) + if timeout is not None: + return await asyncio.wait_for(result, timeout=timeout) + return await result + + if method in {"sleep", "wake_up"}: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self._engine._executor, + self._engine.engine.collective_rpc, + method, + timeout, + args, + kwargs, + None, + ) + + return { + "supported": False, + "todo": True, + "reason": f"Diffusion stage collective_rpc method {method} is not implemented yet", + } + + def shutdown(self) -> None: + for task in self._tasks.values(): + task.cancel() + self._tasks.clear() + self._engine.close() diff --git 
a/vllm_omni/distributed/omni_connectors/adapter.py b/vllm_omni/distributed/omni_connectors/adapter.py index 91a0146c9e..79a824aa31 100644 --- a/vllm_omni/distributed/omni_connectors/adapter.py +++ b/vllm_omni/distributed/omni_connectors/adapter.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# temporary for compatibility with vllm_omni.entrypoints.omni_stage.py -# and vllm_omni.entrypoints.omni_llm.py - import time from collections.abc import Callable from typing import Any -from vllm_omni.entrypoints.stage_utils import OmniStageTaskType from vllm_omni.metrics import OrchestratorAggregator from .utils.logging import get_connector_logger @@ -66,7 +62,7 @@ def try_send_via_connector( if success: # Send lightweight notification via queue notify_payload = { - "type": OmniStageTaskType.GENERATE, + "type": "generate", "request_id": req_id, "sampling_params": sampling_params, "from_connector": True, diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index d733e7bc4b..a1c9a05c8b 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -1,8 +1,10 @@ +import argparse +import dataclasses from dataclasses import dataclass, field from typing import Any import vllm.envs as envs -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.arg_utils import EngineArgs from vllm.logger import init_logger from vllm.transformers_utils.gguf_utils import is_gguf @@ -86,167 +88,11 @@ def __post_init__(self) -> None: load_omni_general_plugins() super().__post_init__() - def _ensure_omni_models_registered(self): - if hasattr(self, "_omni_models_registered"): - return True - register_omni_models_to_vllm() - self._omni_models_registered = True - return True - - def create_model_config(self) -> OmniModelConfig: - """Create an OmniModelConfig from these engine arguments. 
- Returns: - OmniModelConfig instance with all configuration fields set - """ - # GGUF files need a specific model loader path in vLLM. - if is_gguf(self.model): - self.quantization = self.load_format = "gguf" - - if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: - logger.warning( - "The global random seed is set to %d. Since " - "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " - "affect the random state of the Python process that " - "launched vLLM.", - self.seed, - ) - - # register omni models to avoid model not found error - self._ensure_omni_models_registered() - - # Keep compatibility when async args are constructed from partial payloads. - limit_mm_per_prompt = getattr(self, "limit_mm_per_prompt", {}) - language_model_only = getattr(self, "language_model_only", False) - enable_mm_embeds = getattr(self, "enable_mm_embeds", False) - interleave_mm_strings = getattr(self, "interleave_mm_strings", False) - media_io_kwargs = getattr(self, "media_io_kwargs", {}) - skip_mm_profiling = getattr(self, "skip_mm_profiling", False) - mm_processor_kwargs = getattr(self, "mm_processor_kwargs", None) - mm_processor_cache_gb = getattr(self, "mm_processor_cache_gb", 4) - mm_processor_cache_type = getattr(self, "mm_processor_cache_type", None) - mm_shm_cache_max_object_size_mb = getattr(self, "mm_shm_cache_max_object_size_mb", 128) - mm_encoder_only = getattr(self, "mm_encoder_only", False) - mm_encoder_tp_mode = getattr(self, "mm_encoder_tp_mode", "weights") - mm_encoder_attn_backend = getattr(self, "mm_encoder_attn_backend", None) - video_pruning_rate = getattr(self, "video_pruning_rate", 0.0) - - # Build stage_connector_config from stage_connector_spec - stage_connector_config = { - "name": self.stage_connector_spec.get("name", "SharedMemoryConnector"), - "extra": self.stage_connector_spec.get("extra", {}).copy(), - } - stage_connector_config["extra"]["stage_id"] = self.stage_id - - # Create OmniModelConfig directly from engine args - # Note: We pass the actual init 
parameters matching vLLM's EngineArgs.create_model_config() - omni_config = OmniModelConfig( - # Base ModelConfig fields (matching vLLM's EngineArgs.create_model_config) - model=self.model, - model_weights=self.model_weights, - hf_config_path=self.hf_config_path, - runner=self.runner, - convert=self.convert, - tokenizer=self.tokenizer, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code, - allowed_local_media_path=self.allowed_local_media_path, - allowed_media_domains=self.allowed_media_domains, - dtype=self.dtype, - seed=self.seed, - revision=self.revision, - code_revision=self.code_revision, - hf_token=self.hf_token, - hf_overrides=self.hf_overrides, - tokenizer_revision=self.tokenizer_revision, - max_model_len=self.max_model_len, - quantization=self.quantization, - allow_deprecated_quantization=self.allow_deprecated_quantization, - enforce_eager=self.enforce_eager, - enable_return_routed_experts=self.enable_return_routed_experts, - max_logprobs=self.max_logprobs, - logprobs_mode=self.logprobs_mode, - disable_sliding_window=self.disable_sliding_window, - disable_cascade_attn=self.disable_cascade_attn, - skip_tokenizer_init=self.skip_tokenizer_init, - enable_prompt_embeds=self.enable_prompt_embeds, - served_model_name=self.served_model_name, - language_model_only=language_model_only, - limit_mm_per_prompt=limit_mm_per_prompt, - enable_mm_embeds=enable_mm_embeds, - interleave_mm_strings=interleave_mm_strings, - media_io_kwargs=media_io_kwargs, - skip_mm_profiling=skip_mm_profiling, - config_format=self.config_format, - mm_processor_kwargs=mm_processor_kwargs, - mm_processor_cache_gb=mm_processor_cache_gb, - mm_processor_cache_type=mm_processor_cache_type, - mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, - mm_encoder_only=mm_encoder_only, - mm_encoder_tp_mode=mm_encoder_tp_mode, - mm_encoder_attn_backend=mm_encoder_attn_backend, - pooler_config=self.pooler_config, - generation_config=self.generation_config, - 
override_generation_config=self.override_generation_config, - enable_sleep_mode=self.enable_sleep_mode, - model_impl=self.model_impl, - override_attention_dtype=self.override_attention_dtype, - logits_processors=self.logits_processors, - video_pruning_rate=video_pruning_rate, - io_processor_plugin=self.io_processor_plugin, - # Omni-specific fields - stage_id=self.stage_id, - async_chunk=self.async_chunk, - model_stage=self.model_stage, - model_arch=self.model_arch, - worker_type=self.worker_type, - engine_output_type=self.engine_output_type, - hf_config_name=self.hf_config_name, - custom_process_next_stage_input_func=self.custom_process_next_stage_input_func, - stage_connector_config=stage_connector_config, - omni_kv_config=self.omni_kv_config, - task_type=self.task_type, - ) - omni_config.hf_config.architectures = omni_config.architectures - - return omni_config - - -@dataclass -class AsyncOmniEngineArgs(AsyncEngineArgs): - """Async engine arguments for omni models, extending base AsyncEngineArgs. - Adds omni-specific configuration fields for multi-stage pipeline - processing and output type specification in async contexts. - Args: - stage_id: Identifier for the stage in a multi-stage pipeline (default: 0) - model_stage: Stage type identifier, e.g., "thinker" or "talker" - (default: "thinker") - model_arch: Model architecture name - (default: "Qwen2_5OmniForConditionalGeneration") - engine_output_type: Optional output type specification for the engine. - Used to route outputs to appropriate processors (e.g., "image", - "audio", "latents"). If None, output type is inferred. - stage_connector_spec: Extra configuration for stage connector - worker_type: Model Type, e.g., "ar" or "generation" - task_type: Default task type for TTS models (CustomVoice, VoiceDesign, or Base). - If not specified, will be inferred from model path. 
- """ - - stage_id: int = 0 - model_stage: str = "thinker" - model_arch: str | None = None - engine_output_type: str | None = None - hf_config_name: str | None = None - custom_process_next_stage_input_func: str | None = None - stage_connector_spec: dict[str, Any] = field(default_factory=dict) - async_chunk: bool = False - omni_kv_config: dict | None = None - quantization_config: Any | None = None - worker_type: str | None = None - task_type: str | None = None - - def __post_init__(self) -> None: - load_omni_general_plugins() - super().__post_init__() + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "OmniEngineArgs": + attrs = [attr.name for attr in dataclasses.fields(cls)] + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}) + return engine_args def _ensure_omni_models_registered(self): if hasattr(self, "_omni_models_registered"): @@ -277,8 +123,8 @@ def create_model_config(self) -> OmniModelConfig: self._ensure_omni_models_registered() # Keep compatibility when async args are constructed from partial payloads. - limit_mm_per_prompt = getattr(self, "limit_mm_per_prompt", {}) language_model_only = getattr(self, "language_model_only", False) + limit_mm_per_prompt = getattr(self, "limit_mm_per_prompt", {}) enable_mm_embeds = getattr(self, "enable_mm_embeds", False) interleave_mm_strings = getattr(self, "interleave_mm_strings", False) media_io_kwargs = getattr(self, "media_io_kwargs", {}) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py new file mode 100644 index 0000000000..4405cb2eb4 --- /dev/null +++ b/vllm_omni/engine/async_omni_engine.py @@ -0,0 +1,1105 @@ +""" +Async Omni Engine for vLLM-Omni multi-stage runtime. + +AsyncOmniEngine in the caller's thread is a thin proxy that communicates +with the Orchestrator (running in a background thread) via janus queues. 
+""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import dataclasses +import json +import os +import queue +import threading +import time +import uuid +import weakref +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm_omni.engine.arg_utils import OmniEngineArgs + +import janus +import torch +from omegaconf import OmegaConf +from vllm.inputs import PromptType +from vllm.logger import init_logger +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.input_processor import InputProcessor +from vllm.v1.engine.utils import get_engine_zmq_addresses, launch_core_engines + +from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.distributed.omni_connectors.utils.initialization import ( + resolve_omni_kv_config_for_stage, +) +from vllm_omni.engine import ( + OmniEngineCoreRequest, +) +from vllm_omni.engine.orchestrator import Orchestrator +from vllm_omni.engine.output_processor import MultimodalOutputProcessor +from vllm_omni.engine.serialization import serialize_additional_information +from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient +from vllm_omni.engine.stage_init_utils import ( + StartedLlmStage, + acquire_device_locks, + build_engine_args_dict, + build_vllm_config, + cleanup_failed_stage_initialization, + close_started_llm_stage, + extract_stage_metadata, + finalize_initialized_stages, + get_stage_connector_spec, + initialize_diffusion_stage, + load_omni_transfer_config_for_model, + prepare_engine_environment, + release_device_locks, + setup_stage_devices, +) +from vllm_omni.entrypoints.utils import ( + load_and_resolve_stage_configs, +) + +logger = init_logger(__name__) + + +def _inject_kv_stage_info(stage_cfg: Any, stage_id: int) -> None: + """Inject stage_id and engine_input_source into omni_kv_config. 
+ + OmniKVTransferManager needs stage_id to compute recv_stages for the + receiving side. In the old Omni architecture, OmniDiffusion.__init__ + performed this injection; replicate it here for AsyncOmniEngine. + """ + try: + engine_args = stage_cfg.engine_args + if hasattr(engine_args, "get"): + omni_kv = engine_args.get("omni_kv_config", None) + else: + omni_kv = getattr(engine_args, "omni_kv_config", None) + + if omni_kv is None: + return + + if hasattr(omni_kv, "setdefault"): + omni_kv.setdefault("stage_id", stage_id) + elif hasattr(omni_kv, "__setitem__"): + if "stage_id" not in omni_kv: + omni_kv["stage_id"] = stage_id + + engine_input_source = getattr(stage_cfg, "engine_input_source", None) + if engine_input_source is not None: + if hasattr(omni_kv, "setdefault"): + omni_kv.setdefault("engine_input_source", list(engine_input_source)) + elif hasattr(omni_kv, "__setitem__") and "engine_input_source" not in omni_kv: + omni_kv["engine_input_source"] = list(engine_input_source) + except Exception as e: + logger.debug("Failed to inject stage info into omni_kv_config: %s", e) + + +def _inject_global_id(target: Any, request_id: str) -> None: + """Inject global_request_id into a prompt dict's additional_information.""" + if isinstance(target, dict): + if "additional_information" not in target: + target["additional_information"] = {} + if target["additional_information"] is None: + target["additional_information"] = {} + if isinstance(target["additional_information"], dict): + target["additional_information"]["global_request_id"] = [str(request_id)] + + +def _upgrade_to_omni_request( + request: EngineCoreRequest, + raw_prompt: Any, +) -> EngineCoreRequest: + """Restore omni-only fields omitted by upstream InputProcessor.""" + prompt_embeds = request.prompt_embeds + additional_information = None + + if isinstance(raw_prompt, dict): + if prompt_embeds is None: + raw_prompt_embeds = raw_prompt.get("prompt_embeds") + if isinstance(raw_prompt_embeds, torch.Tensor): + 
prompt_embeds = raw_prompt_embeds + additional_information = serialize_additional_information( + raw_prompt.get("additional_information"), + log_prefix="AsyncOmniEngine", + ) + + if prompt_embeds is None and additional_information is None: + return request + + return OmniEngineCoreRequest( + request_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + mm_features=request.mm_features, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + cache_salt=request.cache_salt, + data_parallel_rank=request.data_parallel_rank, + prompt_embeds=prompt_embeds, + client_index=request.client_index, + current_wave=request.current_wave, + priority=request.priority, + trace_headers=request.trace_headers, + resumable=request.resumable, + external_req_id=request.external_req_id, + reasoning_ended=request.reasoning_ended, + additional_information=additional_information, + ) + + +def _weak_shutdown_async_omni_engine( + orchestrator_thread: threading.Thread | None, + request_queue: janus.Queue[dict[str, Any]] | None, + output_queue: janus.Queue[dict[str, Any]] | None, + rpc_output_queue: janus.Queue[dict[str, Any]] | None, +) -> None: + """Best-effort orchestrator cleanup for GC finalization.""" + try: + if request_queue is not None: + request_queue.sync_q.put_nowait({"type": "shutdown"}) + except Exception: + pass + + try: + if orchestrator_thread is not None and orchestrator_thread.is_alive(): + orchestrator_thread.join(timeout=10) + except Exception: + pass + + for q in (request_queue, output_queue, rpc_output_queue): + if q is None: + continue + try: + q.close() + except Exception: + pass + + +class AsyncOmniEngine: + """Thin proxy that launches an Orchestrator in a background thread. 
+ + All stage clients, input/output processors, and stage-to-stage transfer + logic live inside the Orchestrator coroutine (running in its own thread + with a dedicated asyncio event loop). This class communicates with it + via janus queues (sync side for callers, async side for orchestrator). + + Args: + model: Model name or path + init_timeout: Total timeout waiting for orchestrator startup (seconds). + stage_init_timeout: Timeout for stage initialization (seconds) + **kwargs: Additional arguments + """ + + def __init__( + self, + model: str, + engine_args: OmniEngineArgs | None = None, + stage_init_timeout: int = 300, + init_timeout: int = 600, + **kwargs: Any, + ) -> None: + self.model = model + startup_timeout = int(init_timeout) + + logger.info(f"[AsyncOmniEngine] Initializing with model {model}") + + # Merge typed engine_args fields into kwargs; explicit kwargs take priority. + if engine_args is not None: + ea_dict = { + f.name: getattr(engine_args, f.name) + for f in dataclasses.fields(engine_args) + if not f.name.startswith("_") + } + # Remove model since it is passed as a positional arg already. + ea_dict.pop("model", None) + kwargs = {**ea_dict, **kwargs} + + self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs) + + self.num_stages = len(self.stage_configs) + stage0_args = getattr(self.stage_configs[0], "engine_args", None) if self.num_stages > 0 else None + self.async_chunk = bool(getattr(stage0_args, "async_chunk", False)) + self.stage_clients: list[Any] = [] + self.stage_vllm_configs: list[Any] = [] + self.output_processors: list[MultimodalOutputProcessor | None] = [] + self.input_processor: InputProcessor | None = None + self.supported_tasks: tuple[str, ...] 
= ("generate",) + self.default_sampling_params_list: list[Any] = [] + self.stage_metadata: list[dict[str, Any]] = [] + self.request_queue: janus.Queue[dict[str, Any]] | None = None + self.output_queue: janus.Queue[dict[str, Any]] | None = None + self.rpc_output_queue: janus.Queue[dict[str, Any]] | None = None + self._shutdown_called = False + self._weak_finalizer: weakref.finalize | None = None + self._rpc_lock = threading.Lock() + + logger.info(f"[AsyncOmniEngine] Launching Orchestrator thread with {self.num_stages} stages") + + # Launch orchestrator background thread + startup_future: concurrent.futures.Future = concurrent.futures.Future() + + self.orchestrator_thread = threading.Thread( + target=self._bootstrap_orchestrator, + args=( + stage_init_timeout, + startup_future, + ), + daemon=True, + name="orchestrator", + ) + self.orchestrator_thread.start() + + # Wait for stage/runtime initialization result from orchestrator thread. + try: + startup_future.result(timeout=startup_timeout) + except concurrent.futures.TimeoutError as e: + try: + self.shutdown() + except Exception: + logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout") + raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s") from e + except Exception: + try: + self.shutdown() + except Exception: + logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure") + raise + + # Stage runtime fields are assigned directly on self by the bootstrap thread. 
+ self._weak_finalizer = weakref.finalize( + self, + _weak_shutdown_async_omni_engine, + self.orchestrator_thread, + self.request_queue, + self.output_queue, + self.rpc_output_queue, + ) + + logger.info(f"[AsyncOmniEngine] Orchestrator ready with {self.num_stages} stages") + + def _launch_llm_stage( + self, + stage_cfg: Any, + metadata: Any, + stage_connector_spec: dict[str, Any], + stage_init_timeout: int, + llm_stage_launch_lock: threading.Lock, + omni_kv_connector: tuple[dict[str, Any] | None, str | None, str | None] = (None, None, None), + ) -> StartedLlmStage: + """Launch one LLM stage to READY state in a helper thread.""" + from vllm_omni.platforms import current_omni_platform + + started_stage: StartedLlmStage | None = None + lock_fds: list[int] = [] + device_control_env = current_omni_platform.device_control_env_var + + try: + with llm_stage_launch_lock: + previous_visible_devices = os.environ.get(device_control_env) + try: + setup_stage_devices(metadata.stage_id, metadata.runtime_cfg) + engine_args_dict = build_engine_args_dict( + stage_cfg, + self.model, + stage_connector_spec=stage_connector_spec, + ) + omni_conn_cfg, omni_from, omni_to = omni_kv_connector + if omni_conn_cfg: + omni_kv = engine_args_dict.get("omni_kv_config") or {} + if not isinstance(omni_kv, dict): + omni_kv = dict(omni_kv) + omni_kv["connector_config"] = omni_conn_cfg + omni_kv["omni_from_stage"] = omni_from + omni_kv["omni_to_stage"] = omni_to + omni_kv.setdefault("stage_id", metadata.stage_id) + engine_args_dict["omni_kv_config"] = omni_kv + vllm_config, executor_class = build_vllm_config( + stage_cfg, + self.model, + stage_connector_spec=stage_connector_spec, + engine_args_dict=engine_args_dict, + ) + lock_fds = acquire_device_locks( + metadata.stage_id, + engine_args_dict, + stage_init_timeout, + ) + addresses = get_engine_zmq_addresses(vllm_config) + launch_cm = launch_core_engines( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, + 
addresses=addresses, + ) + engine_manager, coordinator, addresses = launch_cm.__enter__() + started_stage = StartedLlmStage( + stage_id=metadata.stage_id, + metadata=metadata, + vllm_config=vllm_config, + executor_class=executor_class, + engine_manager=engine_manager, + coordinator=coordinator, + addresses=addresses, + ) + finally: + if previous_visible_devices is None: + os.environ.pop(device_control_env, None) + else: + os.environ[device_control_env] = previous_visible_devices + + logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id) + launch_cm.__exit__(None, None, None) + logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id) + assert started_stage is not None + return started_stage + except Exception: + if started_stage is not None: + close_started_llm_stage(started_stage) + raise + finally: + if lock_fds: + release_device_locks(lock_fds) + + def _attach_llm_stage( + self, + started: StartedLlmStage, + ) -> tuple[Any, Any, Any, InputProcessor | None]: + """Attach a READY LLM stage to the orchestrator event loop.""" + + client_addresses = { + "input_address": started.addresses.inputs[0], + "output_address": started.addresses.outputs[0], + } + if started.addresses.frontend_stats_publish_address is not None: + client_addresses["stats_update_address"] = started.addresses.frontend_stats_publish_address + + try: + stage_client = StageEngineCoreClient( + vllm_config=started.vllm_config, + executor_class=started.executor_class, + metadata=started.metadata, + client_addresses=client_addresses, + engine_manager=started.engine_manager, + coordinator=started.coordinator, + ) + started.engine_manager = None + started.coordinator = None + except Exception: + close_started_llm_stage(started) + raise + + try: + if started.vllm_config.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = cached_tokenizer_from_config( + model_config=started.vllm_config.model_config, + ) + output_processor = 
MultimodalOutputProcessor( + tokenizer=tokenizer, + log_stats=False, + engine_core_output_type=started.metadata.engine_output_type, + ) + input_processor = None + if started.stage_id == 0: + input_processor = InputProcessor(vllm_config=started.vllm_config) + except Exception: + try: + stage_client.shutdown() + except Exception as cleanup_error: + logger.warning( + "[AsyncOmniEngine] Failed to cleanup stage %s after attach failure: %s", + started.stage_id, + cleanup_error, + ) + raise + + logger.info("[AsyncOmniEngine] Stage %s initialized", started.stage_id) + return stage_client, output_processor, started.vllm_config, input_processor + + def _initialize_stages(self, stage_init_timeout: int) -> None: + """Initialize stage clients/processors in orchestrator thread and assign to self.""" + + num_stages = self.num_stages + stage_clients: list[Any | None] = [None] * num_stages + output_processors: list[Any | None] = [None] * num_stages + stage_vllm_configs: list[Any | None] = [None] * num_stages + input_processor: InputProcessor | None = None + llm_stage_ids: list[int] = [] + llm_launch_futures: dict[int, concurrent.futures.Future[StartedLlmStage]] = {} + started_llm_stages: dict[int, StartedLlmStage] = {} + llm_stage_launch_lock = threading.Lock() + + async_chunk = self.async_chunk + prompt_expand_func = None + llm_stage_count = sum( + 1 for stage_cfg in self.stage_configs if getattr(stage_cfg, "stage_type", "llm") != "diffusion" + ) + + prepare_engine_environment() + omni_transfer_config = load_omni_transfer_config_for_model(self.model, self.config_path) + + try: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(1, llm_stage_count), + thread_name_prefix="llm-stage-launch", + ) as launch_executor: + for stage_id, stage_cfg in enumerate(self.stage_configs): + logger.info("[AsyncOmniEngine] Initializing stage %s", stage_id) + metadata = extract_stage_metadata(stage_cfg) + if metadata.prompt_expand_func is not None: + prompt_expand_func = 
metadata.prompt_expand_func + + stage_connector_spec = get_stage_connector_spec( + omni_transfer_config=omni_transfer_config, + stage_id=stage_id, + async_chunk=async_chunk, + ) + + omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, stage_id) + + if metadata.stage_type == "diffusion": + setup_stage_devices(stage_id, metadata.runtime_cfg) + omni_conn_cfg, omni_from, omni_to = omni_kv_connector + if omni_conn_cfg: + from vllm_omni.entrypoints.utils import inject_omni_kv_config + + inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to) + _inject_kv_stage_info(stage_cfg, stage_id) + stage_clients[stage_id] = initialize_diffusion_stage(self.model, stage_cfg, metadata) + logger.info("[AsyncOmniEngine] Stage %s initialized (diffusion)", stage_id) + continue + + llm_stage_ids.append(stage_id) + llm_launch_futures[stage_id] = launch_executor.submit( + self._launch_llm_stage, + stage_cfg, + metadata, + stage_connector_spec, + stage_init_timeout, + llm_stage_launch_lock, + omni_kv_connector, + ) + + concurrent.futures.wait(list(llm_launch_futures.values())) + + for stage_id in llm_stage_ids: + started_llm_stages[stage_id] = llm_launch_futures[stage_id].result() + + for stage_id in llm_stage_ids: + started = started_llm_stages[stage_id] + stage_client, output_processor, vllm_config, stage0_input_processor = self._attach_llm_stage(started) + stage_clients[stage_id] = stage_client + output_processors[stage_id] = output_processor + stage_vllm_configs[stage_id] = vllm_config + if stage0_input_processor is not None: + input_processor = stage0_input_processor + + initialized_stage_clients, default_sampling_params_list, stage_metadata = finalize_initialized_stages( + stage_clients, + input_processor, + ) + except Exception: + for stage_id, future in llm_launch_futures.items(): + if not future.done() or future.cancelled() or future.exception() is not None: + continue + started_llm_stages.setdefault(stage_id, future.result()) + logger.exception( + 
"[AsyncOmniEngine] Stage initialization failed; shutting down %s initialized stage(s)", + len([stage_client for stage_client in stage_clients if stage_client is not None]), + ) + cleanup_failed_stage_initialization( + stage_clients, + [started_llm_stages[stage_id] for stage_id in llm_stage_ids if stage_id in started_llm_stages], + ) + raise + + self.stage_clients = initialized_stage_clients + self.output_processors = output_processors + self.stage_vllm_configs = stage_vllm_configs + self.input_processor = input_processor + self.prompt_expand_func = prompt_expand_func + # TODO(Peiqi): Hack here + supported_tasks: set[str] = set() + if any(getattr(stage_client, "is_comprehension", False) for stage_client in initialized_stage_clients): + supported_tasks.add("generate") + if any(metadata.get("final_output_type") == "audio" for metadata in stage_metadata): + supported_tasks.add("speech") + self.supported_tasks = tuple(supported_tasks) if supported_tasks else ("generate",) + + self.default_sampling_params_list = default_sampling_params_list + self.stage_metadata = stage_metadata + + def _initialize_janus_queues(self) -> None: + """Initialize janus queues inside orchestrator thread loop context.""" + self.request_queue = janus.Queue() + self.output_queue = janus.Queue() + self.rpc_output_queue = janus.Queue() + logger.debug("[AsyncOmniEngine] janus queues initialized in orchestrator thread loop") + + def _bootstrap_orchestrator( + self, + stage_init_timeout: int, + startup_future: concurrent.futures.Future, + ) -> None: + """Create loop, initialize stages, then run Orchestrator.""" + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def _run_orchestrator() -> None: + self._initialize_janus_queues() + + self._initialize_stages(stage_init_timeout) + orchestrator = Orchestrator( + request_async_queue=self.request_queue.async_q, + output_async_queue=self.output_queue.async_q, + rpc_async_queue=self.rpc_output_queue.async_q, + 
async_chunk=self.async_chunk, + stage_clients=self.stage_clients, + output_processors=self.output_processors, + stage_vllm_configs=self.stage_vllm_configs, + ) + if not startup_future.done(): + startup_future.set_result(asyncio.get_running_loop()) + await orchestrator.run() + + try: + loop.run_until_complete(_run_orchestrator()) + except Exception as e: + if not startup_future.done(): + startup_future.set_exception(RuntimeError(f"Orchestrator initialization failed: {e}")) + logger.exception("[AsyncOmniEngine] Orchestrator thread crashed") + try: + if self.output_queue is not None: + self.output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"}) + if self.rpc_output_queue is not None: + self.rpc_output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"}) + except Exception: + pass + raise + finally: + try: + pending = [task for task in asyncio.all_tasks(loop) if not task.done()] + for task in pending: + task.cancel() + if pending: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, "shutdown_default_executor"): + loop.run_until_complete(loop.shutdown_default_executor()) + except Exception: + logger.exception("[AsyncOmniEngine] Failed during orchestrator loop cleanup") + finally: + asyncio.set_event_loop(None) + loop.close() + + # ---- request helpers ---- + + def _build_add_request_message( + self, + request_id: str, + prompt: EngineCoreRequest | PromptType, + sampling_params_list: Sequence[Any] | None = None, + final_stage_id: int = 0, + arrival_time: float | None = None, + ) -> dict[str, Any]: + """Build an add_request message after stage-0 preprocessing.""" + effective_sampling_params_list = ( + list(sampling_params_list) if sampling_params_list is not None else list(self.default_sampling_params_list) + ) + if not effective_sampling_params_list: + raise ValueError( + f"Missing sampling params for stage 0. 
Got {len(effective_sampling_params_list)} stage params." + ) + params = effective_sampling_params_list[0] + + # Keep the original prompt for downstream stages (they need the raw + # dict, e.g. for multi_modal_data). + original_prompt = prompt + + stage_type = self.stage_metadata[0].get("stage_type") + if stage_type != "diffusion" and not isinstance(prompt, EngineCoreRequest): + # Inject global_request_id into the raw prompt. + if isinstance(prompt, dict): + _inject_global_id(prompt, request_id) + elif isinstance(prompt, list): + for item in prompt: + _inject_global_id(item, request_id) + + # Full input processing (tokenization, multimodal, etc.) + request = self.input_processor.process_inputs( + request_id=request_id, + prompt=prompt, + params=params, + supported_tasks=self.supported_tasks, + arrival_time=arrival_time, + ) + # TODO (Peiqi): add this for Qwen3-TTS only. Other models don't have + # additional_information field in the prompt. + request = _upgrade_to_omni_request(request, prompt) + + # Restore external_req_id to the original user-facing request_id. + # InputProcessor.process_inputs() renames request_id to an internal + # UUID (saving the original in external_req_id), but then overwrites + # external_req_id with the new internal ID. We need external_req_id + # to match the key used in Orchestrator.request_states so that + # output routing (output.request_id lookup) can find the req_state. + request.external_req_id = request_id + + # Register with stage 0's output processor. 
+ self.output_processors[0].add_request( + request=request, + prompt=prompt, + parent_req=None, + request_index=0, + queue=None, + ) + prompt = request + + return { + "type": "add_request", + "request_id": request_id, + "prompt": prompt, + "original_prompt": original_prompt, + "sampling_params_list": effective_sampling_params_list, + "final_stage_id": final_stage_id, + } + + def _enqueue_cfg_companions( + self, + parent_id: str, + original_prompt: Any, + stage0_params: Any, + sampling_params_list: list[Any], + ) -> None: + """Expand prompt into CFG companions, process through InputProcessor, and enqueue.""" + try: + expanded = self.prompt_expand_func(original_prompt, stage0_params) + except Exception: + logger.exception("[AsyncOmniEngine] prompt_expand_func failed for req %s", parent_id) + return + + if not expanded: + return + + for ep in expanded: + cid = f"{parent_id}{ep.request_id_suffix}" + companion_prompt = ep.prompt + + # Run through same input processing as the main prompt + if isinstance(companion_prompt, dict): + _inject_global_id(companion_prompt, cid) + + request = self.input_processor.process_inputs( + request_id=cid, + prompt=companion_prompt, + params=stage0_params, + supported_tasks=self.supported_tasks, + ) + request = _upgrade_to_omni_request(request, companion_prompt) + request.external_req_id = cid + + self.output_processors[0].add_request( + request=request, + prompt=companion_prompt, + parent_req=None, + request_index=0, + queue=None, + ) + + self.request_queue.sync_q.put_nowait( + { + "type": "add_companion_request", + "companion_id": cid, + "parent_id": parent_id, + "role": ep.role, + "prompt": request, + "sampling_params_list": sampling_params_list, + } + ) + + logger.info( + "[AsyncOmniEngine] CFG expansion for req %s: %d companions", + parent_id, + len(expanded), + ) + + @staticmethod + def _get_default_cache_config(cache_backend: str | None) -> dict[str, Any] | None: + if cache_backend == "cache_dit": + return { + "Fn_compute_blocks": 
1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + "residual_diff_threshold": 0.24, + "max_continuous_cached_steps": 3, + "enable_taylorseer": False, + "taylorseer_order": 1, + "scm_steps_mask_policy": None, + "scm_steps_policy": "dynamic", + } + if cache_backend == "tea_cache": + return { + "rel_l1_thresh": 0.2, + } + return None + + @staticmethod + def _normalize_cache_config(cache_backend: str | None, cache_config: Any | None) -> Any | None: + if isinstance(cache_config, str): + try: + cache_config = json.loads(cache_config) + except json.JSONDecodeError: + logger.warning("Invalid cache_config JSON, using defaults.") + cache_config = None + if cache_config is None and cache_backend not in (None, "", "none"): + cache_config = AsyncOmniEngine._get_default_cache_config(cache_backend) + return cache_config + + @staticmethod + def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: + """Create a default single-stage diffusion config from kwargs.""" + # We temporarily create a default config for diffusion stage. + # In the future, we should merge the default config with the user-provided config. + normalized_kwargs = dict(kwargs) + + # TODO: hack, convert dtype to string to avoid non-primitive omegaconf create error. 
+ if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str): + if not isinstance(normalized_kwargs["dtype"], torch.dtype): + raise TypeError( + f"Provided dtype must be a string or torch.dtype, got {type(normalized_kwargs['dtype']).__name__}" + ) + normalized_kwargs["dtype"] = str(normalized_kwargs["dtype"]).removeprefix("torch.") + + cache_backend = normalized_kwargs.get("cache_backend", "none") + cache_config = AsyncOmniEngine._normalize_cache_config( + cache_backend, + normalized_kwargs.get("cache_config", None), + ) + + parallel_config = normalized_kwargs.get("parallel_config") + if isinstance(parallel_config, dict): + parallel_config = DiffusionParallelConfig.from_dict(parallel_config) + if parallel_config is None: + ulysses_degree = normalized_kwargs.get("ulysses_degree") or 1 + ring_degree = normalized_kwargs.get("ring_degree") or 1 + sequence_parallel_size = normalized_kwargs.get("sequence_parallel_size") + tensor_parallel_size = normalized_kwargs.get("tensor_parallel_size") or 1 + cfg_parallel_size = normalized_kwargs.get("cfg_parallel_size") or 1 + vae_patch_parallel_size = normalized_kwargs.get("vae_patch_parallel_size") or 1 + use_hsdp = normalized_kwargs.get("use_hsdp", False) + hsdp_shard_size = normalized_kwargs.get("hsdp_shard_size", -1) + hsdp_replicate_size = normalized_kwargs.get("hsdp_replicate_size", 1) + if sequence_parallel_size is None: + sequence_parallel_size = ulysses_degree * ring_degree + + parallel_config = DiffusionParallelConfig( + pipeline_parallel_size=1, + data_parallel_size=1, + tensor_parallel_size=tensor_parallel_size, + sequence_parallel_size=sequence_parallel_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + cfg_parallel_size=cfg_parallel_size, + vae_patch_parallel_size=vae_patch_parallel_size, + use_hsdp=use_hsdp, + hsdp_shard_size=hsdp_shard_size, + hsdp_replicate_size=hsdp_replicate_size, + ) + + num_devices = max(1, int(parallel_config.world_size)) + devices = ",".join(str(i) 
for i in range(num_devices)) + + default_stage_cfg = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": { + "process": True, + "devices": devices, + "max_batch_size": 1, + }, + "engine_args": { + "parallel_config": parallel_config, + "model_class_name": kwargs.get("model_class_name", None), + "vae_use_slicing": kwargs.get("vae_use_slicing", False), + "vae_use_tiling": kwargs.get("vae_use_tiling", False), + "cache_backend": cache_backend, + "cache_config": cache_config, + "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), + "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), + "enforce_eager": kwargs.get("enforce_eager", False), + "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), + "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), + "worker_extension_cls": kwargs.get("worker_extension_cls", None), + "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), + "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), + "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4), + }, + "final_output": True, + "final_output_type": "image", + } + ] + default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" + return default_stage_cfg + + def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]: + """Resolve stage configs and inject defaults shared by orchestrator/headless.""" + + stage_configs_path = kwargs.get("stage_configs_path", None) + explicit_stage_configs = kwargs.pop("stage_configs", None) + if explicit_stage_configs is not None: + logger.warning( + "`stage_configs` is not part of the public API. " + "Ignoring it and resolving stages from stage_configs_path/model factory." + ) + + # Use the legacy config loading path (load_and_resolve_stage_configs). + # StageConfigFactory wiring will be done in config refactor [2/N]. 
+ config_path, stage_configs = load_and_resolve_stage_configs( + model, + stage_configs_path, + kwargs, + default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs), + ) + + # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. + for cfg in stage_configs: + try: + if getattr(cfg, "stage_type", None) != "diffusion": + continue + if not hasattr(cfg, "engine_args") or cfg.engine_args is None: + cfg.engine_args = OmegaConf.create({}) + if kwargs.get("lora_path") is not None: + if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None: + cfg.engine_args.lora_path = kwargs["lora_path"] + lora_scale = kwargs.get("lora_scale") + if lora_scale is None: + # Backwards compatibility for older callers. + lora_scale = kwargs.get("static_lora_scale") + if lora_scale is not None: + if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: + cfg.engine_args.lora_scale = lora_scale + quantization_config = kwargs.get("quantization_config") + if quantization_config is not None: + if ( + not hasattr(cfg.engine_args, "quantization_config") + or cfg.engine_args.quantization_config is None + ): + cfg.engine_args.quantization_config = quantization_config + except Exception as e: + logger.warning("Failed to inject LoRA config for stage: %s", e) + + return config_path, stage_configs + + # ==================== Public API ==================== + + def add_request( + self, + request_id: str, + prompt: EngineCoreRequest | PromptType, + sampling_params_list: Sequence[Any] | None = None, + final_stage_id: int = 0, + arrival_time: float | None = None, + ) -> None: + """Process stage-0 input locally, then send to the Orchestrator. + + Input processing and output + processor registration happen here in the caller's thread, avoiding + a queue + coroutine-switch round-trip. The Orchestrator receives a + ready-to-submit OmniEngineCoreRequest. 
+ """ + msg = self._build_add_request_message( + request_id=request_id, + prompt=prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id, + arrival_time=arrival_time, + ) + if self.request_queue is None: + raise RuntimeError("request_queue is not initialized") + self.request_queue.sync_q.put_nowait(msg) + + # CFG companion expansion: create and enqueue companion requests + # so the AR stage also generates their KV caches. + if self.prompt_expand_func is not None and final_stage_id > 0: + original_prompt = msg.get("original_prompt", prompt) + effective_spl = msg.get("sampling_params_list", []) + stage0_params = effective_spl[0] if effective_spl else None + if stage0_params is not None: + self._enqueue_cfg_companions(request_id, original_prompt, stage0_params, effective_spl) + + async def add_request_async( + self, + request_id: str, + prompt: EngineCoreRequest | PromptType, + sampling_params_list: Sequence[Any] | None = None, + final_stage_id: int = 0, + arrival_time: float | None = None, + ) -> None: + """Async add_request API.""" + self.add_request( + request_id=request_id, + prompt=prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id, + arrival_time=arrival_time, + ) + + def try_get_output(self, timeout: float = 0.001) -> dict[str, Any] | None: + """Read one output message from the Orchestrator output queue.""" + if self.output_queue is None: + return None + try: + return self.output_queue.sync_q.get(timeout=timeout) + except queue.Empty: + return None + + async def try_get_output_async(self) -> dict[str, Any] | None: + """Async read from the Orchestrator output queue.""" + if self.output_queue is None: + return None + try: + return self.output_queue.sync_q.get_nowait() + except queue.Empty: + return None + + def get_stage_metadata(self, stage_id: int) -> dict[str, Any]: + """Get cached metadata for a stage.""" + return self.stage_metadata[stage_id] + + def abort(self, request_ids: list[str]) -> None: + 
"""Send abort message to the Orchestrator.""" + if self.request_queue is None: + raise RuntimeError("request_queue is not initialized") + self.request_queue.sync_q.put_nowait( + { + "type": "abort", + "request_ids": request_ids, + } + ) + + async def abort_async(self, request_ids: list[str]) -> None: + """Async abort API.""" + self.abort(request_ids) + + def collective_rpc( + self, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + stage_ids: list[int] | None = None, + ) -> list[Any]: + """Send a control RPC to the Orchestrator and wait for aggregated results. + + This uses a dedicated RPC output queue so control-plane messages do not + race with the normal request output polling loop. + """ + if self.request_queue is None: + raise RuntimeError("request_queue is not initialized") + if self.rpc_output_queue is None: + raise RuntimeError("rpc_output_queue is not initialized") + + rpc_id = uuid.uuid4().hex + msg = { + "type": "collective_rpc", + "rpc_id": rpc_id, + "method": method, + "args": tuple(args), + "kwargs": kwargs or {}, + "stage_ids": stage_ids, + } + + with self._rpc_lock: + self.request_queue.sync_q.put_nowait(msg) + deadline = None if timeout is None else time.monotonic() + timeout + + while True: + remaining = None if deadline is None else max(0.0, deadline - time.monotonic()) + try: + result_msg = self.rpc_output_queue.sync_q.get(timeout=remaining) + except queue.Empty as exc: + raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") from exc + + if result_msg.get("type") == "error": + raise RuntimeError(result_msg.get("error", "Orchestrator returned an error message")) + + if result_msg.get("type") != "collective_rpc_result": + logger.warning( + "[AsyncOmniEngine] Dropping unexpected rpc queue message type=%s", + result_msg.get("type"), + ) + continue + + if result_msg.get("rpc_id") != rpc_id: + logger.warning( + "[AsyncOmniEngine] Dropping mismatched rpc result 
rpc_id=%s expected=%s", + result_msg.get("rpc_id"), + rpc_id, + ) + continue + + return list(result_msg.get("results", [])) + + async def collective_rpc_async( + self, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + stage_ids: list[int] | None = None, + ) -> list[Any]: + """Async wrapper around collective_rpc().""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: self.collective_rpc( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, + stage_ids=stage_ids, + ), + ) + + def is_alive(self) -> bool: + """Whether the orchestrator thread is alive.""" + return bool(self.orchestrator_thread.is_alive()) + + def shutdown(self) -> None: + """Send shutdown message and wait for the Orchestrator thread to exit.""" + if getattr(self, "_shutdown_called", False): + return + self._shutdown_called = True + finalizer = getattr(self, "_weak_finalizer", None) + if finalizer is not None and finalizer.alive: + finalizer.detach() + + logger.info("[AsyncOmniEngine] Shutting down Orchestrator") + try: + if self.request_queue is not None: + self.request_queue.sync_q.put_nowait({"type": "shutdown"}) + except Exception: + pass + if self.is_alive(): + self.orchestrator_thread.join(timeout=10) + if self.orchestrator_thread.is_alive(): + logger.warning("[AsyncOmniEngine] Orchestrator thread did not exit in time") + + for q in (self.request_queue, self.output_queue, self.rpc_output_queue): + if q is None: + continue + try: + q.close() + except Exception: + pass diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py deleted file mode 100644 index 5bbd16b38d..0000000000 --- a/vllm_omni/engine/input_processor.py +++ /dev/null @@ -1,300 +0,0 @@ -import time -from collections.abc import Mapping -from typing import Any - -import numpy as np -import torch -from vllm.config import VllmConfig -from vllm.inputs import ProcessorInputs, PromptType 
-from vllm.inputs.parse import split_enc_dec_inputs -from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import MultiModalFeatureSpec -from vllm.multimodal.utils import argsort_mm_positions -from vllm.pooling_params import PoolingParams -from vllm.renderers import BaseRenderer -from vllm.sampling_params import SamplingParams -from vllm.tasks import SupportedTask -from vllm.utils import length_from_prompt_token_ids_or_embeds -from vllm.utils.jsontree import json_iter_leaves -from vllm.v1.engine.input_processor import InputProcessor - -from vllm_omni.engine import ( - AdditionalInformationEntry, - AdditionalInformationPayload, - OmniEngineCoreRequest, - PromptEmbedsPayload, -) -from vllm_omni.inputs.preprocess import OmniInputPreprocessor -from vllm_omni.lora.request import LoRARequest - -logger = init_logger(__name__) - -_OMNI_EXTRA_KEYS = ( - "additional_information", - "prompt_embeds", - "negative_prompt", - "negative_prompt_embeds", -) - - -def reinject_omni_fields( - results: list[ProcessorInputs], - original_prompts: list[dict], -) -> None: - """Re-inject omni-specific fields that the upstream renderer discards. - - The upstream renderer's ``process_for_engine`` creates new dicts that only - copy standard vLLM fields (prompt_token_ids, multi_modal_data, …). - Omni-specific fields such as ``additional_information`` and - ``prompt_embeds`` are silently dropped. This helper copies them back from - the *original* parsed prompts into the renderer outputs so they survive - into ``OmniInputProcessor.process_inputs()``. - """ - for result, orig in zip(results, original_prompts): - if not isinstance(orig, dict): - continue - for key in _OMNI_EXTRA_KEYS: - val = orig.get(key) - if val is not None and key not in result: - result[key] = val - - -class OmniInputProcessor(InputProcessor): - """Processor for omni models, handling multimodal inputs and embeddings. 
- - Extends the base vLLM Processor with support for processing prompt - embeddings and additional information payloads, enabling direct transfer - of pre-computed embeddings between pipeline stages. - - Args: - vllm_config: Global vLLM configuration - mm_registry: Multi-modal registry for processing multimodal inputs - """ - - @staticmethod - def _dtype_to_name(dtype: torch.dtype) -> str: - """Convert torch dtype to string representation. - - Args: - dtype: PyTorch dtype to convert - - Returns: - String representation of the dtype (e.g., "float32", "int64") - """ - mapping = { - torch.float32: "float32", - torch.float: "float32", - torch.float16: "float16", - torch.half: "float16", - torch.bfloat16: "bfloat16", - torch.float64: "float64", - torch.double: "float64", - torch.int64: "int64", - torch.long: "int64", - torch.int32: "int32", - torch.int: "int32", - torch.int16: "int16", - torch.short: "int16", - torch.int8: "int8", - torch.uint8: "uint8", - torch.bool: "bool", - } - return mapping.get(dtype, str(dtype).replace("torch.", "")) - - def __init__( - self, - vllm_config: VllmConfig, - renderer: BaseRenderer | None = None, - *, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - super().__init__(vllm_config, renderer=renderer, mm_registry=mm_registry) - self.input_preprocessor = OmniInputPreprocessor( - vllm_config=vllm_config, - renderer=self.renderer, - mm_registry=mm_registry, - ) - - def process_inputs( - self, - request_id: str, - prompt: PromptType | ProcessorInputs, - params: SamplingParams | PoolingParams, - supported_tasks: tuple[SupportedTask, ...] = ("generate",), - arrival_time: float | None = None, - lora_request: LoRARequest | None = None, - tokenization_kwargs: dict[str, Any] | None = None, - trace_headers: Mapping[str, str] | None = None, - priority: int = 0, - data_parallel_rank: int | None = None, - resumable: bool = False, - ) -> OmniEngineCoreRequest: - """Process input prompt into an engine core request. 
- - Converts a prompt (text, tokens, or multimodal) into an - OmniEngineCoreRequest that can be processed by the engine. - Handles prompt embeddings and additional information payloads - for direct transfer between stages. - - Args: - request_id: Unique identifier for this request - prompt: Input prompt (text, token IDs, embeddings, or multimodal) - params: Sampling or pooling parameters for generation - supported_tasks: Tuple of supported tasks for validation - arrival_time: Optional arrival timestamp (defaults to current time) - lora_request: Optional LoRA adapter request - tokenization_kwargs: Optional additional tokenization arguments - trace_headers: Optional tracing headers for observability - priority: Request priority (higher values processed first) - data_parallel_rank: Optional data parallel rank for distributed - inference - resumable: Whether the request supports streaming input - - Returns: - OmniEngineCoreRequest ready for the engine - - Raises: - ValueError: If data_parallel_rank is out of range or prompt_embeds - has incorrect shape - """ - self._validate_params(params, supported_tasks) - self._validate_lora(lora_request) - - parallel_config = self.vllm_config.parallel_config - dp_size = parallel_config.data_parallel_size - dp_local_size = parallel_config.data_parallel_size_local - num_ranks = dp_local_size if parallel_config.local_engines_only else dp_size - if data_parallel_rank is not None and not (0 <= data_parallel_rank < num_ranks): - raise ValueError(f"data_parallel_rank {data_parallel_rank} is out of range [0, {num_ranks}).") - - # Short-circuit for prompts already processed by the renderer - # (they carry a "type" key). Raw prompts must still go through the - # omni preprocessor which preserves additional_information, etc. 
- if isinstance(prompt, dict) and "type" in prompt: - if arrival_time is None: - arrival_time = prompt.get("arrival_time", time.time()) - processed_inputs: ProcessorInputs = prompt # type: ignore[assignment] - else: - if arrival_time is None: - arrival_time = time.time() - - processed_inputs = self.input_preprocessor.preprocess( - prompt, - tokenization_kwargs=tokenization_kwargs, - ) - - self._platform_validate_request(processed_inputs, params) - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - self._validate_model_inputs(encoder_inputs, decoder_inputs) - - # Normalize decoder prompt access across TypedDict variants. - if decoder_inputs["type"] == "embeds": - prompt_token_ids = None - prompt_embeds = decoder_inputs["prompt_embeds"] - else: - prompt_token_ids = decoder_inputs["prompt_token_ids"] - prompt_embeds = decoder_inputs.get("prompt_embeds") - - sampling_params = None - pooling_params = None - if isinstance(params, SamplingParams): - # TODO: can we avoid cloning here in multiproc case? - sampling_params = params.clone() - # If unset max tokens, then generate up to the max_model_len. - if sampling_params.max_tokens is None: - seq_len = length_from_prompt_token_ids_or_embeds(prompt_token_ids, prompt_embeds) - sampling_params.max_tokens = self.model_config.max_model_len - seq_len - sampling_params.update_from_generation_config( - self.generation_config_fields, - self.renderer.get_eos_token_id(), - ) - if self.tokenizer is not None: - sampling_params.update_from_tokenizer(self.tokenizer) - else: - pooling_params = params.clone() - - # Multimodal related. 
- mm_features: list[MultiModalFeatureSpec] | None = None - - if decoder_inputs["type"] == "multimodal": - decoder_mm_inputs = decoder_inputs["mm_kwargs"] - decoder_mm_positions = decoder_inputs["mm_placeholders"] - decoder_mm_hashes = decoder_inputs["mm_hashes"] - - if not all(isinstance(leaf, str) for leaf in json_iter_leaves(decoder_mm_hashes)): - raise ValueError( - f"mm_hashes must contain only strings, got: {decoder_mm_hashes}. " - "This is likely due to an incorrect custom implementation of " - "MultiModalProcessor.apply method." - ) - - # Merge and flatten multimodal placeholders, hashes and inputs - # from dictionaries to lists, and sort them by each item's position - # in the input sequence. - sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - - mm_features = [] - for modality, idx in sorted_mm_idxs: - base_mm_hash = decoder_mm_hashes[modality][idx] - mm_features.append( - MultiModalFeatureSpec( - data=decoder_mm_inputs[modality][idx], - modality=modality, - identifier=self._get_mm_identifier(base_mm_hash, lora_request), - mm_position=decoder_mm_positions[modality][idx], - mm_hash=base_mm_hash, - ) - ) - - # Compatibility: decode serialized prompt embeds if provided. 
- if isinstance(prompt_embeds, PromptEmbedsPayload): - prompt_embeds = self._decode_prompt_embeds(prompt_embeds) - - additional_information_payload: AdditionalInformationPayload | None = None - raw_info: dict[str, Any] | AdditionalInformationPayload | None = decoder_inputs.get("additional_information") - if isinstance(raw_info, AdditionalInformationPayload): - additional_information_payload = raw_info - elif raw_info is not None: - entries: dict[str, AdditionalInformationEntry] = {} - for key, value in raw_info.items(): - if isinstance(value, torch.Tensor): - v_cpu = value.detach().to("cpu").contiguous() - dtype_str = self._dtype_to_name(v_cpu.dtype) - data_bytes = v_cpu.numpy().tobytes() - entry = AdditionalInformationEntry( - tensor_data=data_bytes, - tensor_shape=[int(x) for x in list(v_cpu.shape)], - tensor_dtype=dtype_str, - ) - elif isinstance(value, list): - entry = AdditionalInformationEntry(list_data=value) - else: - raise ValueError("additional_information values must be Tensor or list") - entries[key] = entry - additional_information_payload = AdditionalInformationPayload(entries=entries) - - return OmniEngineCoreRequest( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - mm_features=mm_features, - sampling_params=sampling_params, - pooling_params=pooling_params, - arrival_time=arrival_time, - lora_request=lora_request, - cache_salt=decoder_inputs.get("cache_salt"), - priority=priority, - data_parallel_rank=data_parallel_rank, - trace_headers=trace_headers, - prompt_embeds=prompt_embeds, - additional_information=additional_information_payload, - resumable=resumable, - ) - - @staticmethod - def _decode_prompt_embeds(payload: PromptEmbedsPayload) -> torch.Tensor: - dtype = getattr(np, payload.dtype) - arr = np.frombuffer(payload.data, dtype=dtype) - arr = arr.reshape(payload.shape) - return torch.from_numpy(arr) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py new file mode 100644 index 0000000000..a94879133a 
--- /dev/null +++ b/vllm_omni/engine/orchestrator.py @@ -0,0 +1,812 @@ +""" +Orchestrator for vLLM-Omni multi-stage runtime. + +Runs inside a background thread with its own asyncio event loop. +Owns all StageEngineCoreClient instances, input/output processors, +and handles stage-to-stage transfer logic. +""" + +from __future__ import annotations + +import asyncio +import copy +import time as _time +from dataclasses import dataclass, field +from typing import Any + +import janus +import torch +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import EngineCoreOutputs + +from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length +from vllm_omni.engine import ( + OmniEngineCoreRequest, +) +from vllm_omni.engine.serialization import serialize_additional_information +from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics +from vllm_omni.metrics.stats import StageStats +from vllm_omni.metrics.utils import count_tokens_from_outputs + +logger = init_logger(__name__) + + +def build_engine_core_request_from_tokens( + request_id: str, + prompt: dict[str, Any], + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + model_config: ModelConfig | None = None, +) -> OmniEngineCoreRequest: + """Build an OmniEngineCoreRequest directly from an OmniTokensPrompt. + + Lightweight alternative to the full InputProcessor pipeline - skips + tokenization, multimodal preprocessing, LoRA validation, and platform + validation. Intended for stage 1+ where the upstream stage has already + produced token IDs and optional embeddings. 
+ """ + if arrival_time is None: + arrival_time = _time.time() + + prompt_token_ids = prompt["prompt_token_ids"] + + # Clone params and set max_tokens if needed + sampling_params = None + pooling_params = None + if isinstance(params, SamplingParams): + sampling_params = params.clone() + if sampling_params.max_tokens is None and model_config is not None: + sampling_params.max_tokens = model_config.max_model_len - len(prompt_token_ids) + else: + pooling_params = params.clone() + + prompt_embeds: torch.Tensor | None = prompt.get("prompt_embeds") + + # Serialize additional_information if present + additional_info_payload = serialize_additional_information( + prompt.get("additional_information"), + log_prefix=f"build_engine_core_request_from_tokens req={request_id}", + ) + + return OmniEngineCoreRequest( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=sampling_params, + pooling_params=pooling_params, + arrival_time=arrival_time, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + prompt_embeds=prompt_embeds, + additional_information=additional_info_payload, + ) + + +# ============================================================ +# Orchestrator internals (run inside the background thread) +# ============================================================ + + +@dataclass +class OrchestratorRequestState: + """Per-request bookkeeping inside the Orchestrator.""" + + request_id: str + prompt: Any = None + sampling_params_list: list[Any] = field(default_factory=list) + final_stage_id: int = -1 + + # Metrics: timestamp when request was submitted to each stage + stage_submit_ts: dict[int, float] = field(default_factory=dict) + + +class Orchestrator: + """Runs inside a background thread's asyncio event loop. + + Owns all StageEngineCoreClient instances, input/output processors, + and handles stage-to-stage transfer logic. 
+ """ + + def __init__( + self, + request_async_queue: janus.AsyncQueue[dict[str, Any]], + output_async_queue: janus.AsyncQueue[dict[str, Any]], + rpc_async_queue: janus.AsyncQueue[dict[str, Any]], + stage_clients: list[Any], + output_processors: list[Any], + stage_vllm_configs: list[Any], + *, + async_chunk: bool = False, + ) -> None: + self.request_async_queue = request_async_queue + self.output_async_queue = output_async_queue + self.rpc_async_queue = rpc_async_queue + + self.num_stages = len(stage_clients) + self.async_chunk = bool(async_chunk) + + self.stage_clients: list[Any] = stage_clients + self.output_processors: list[Any] = output_processors + self.stage_vllm_configs: list[Any] = stage_vllm_configs + + # Per-request state + self.request_states: dict[str, OrchestratorRequestState] = {} + + # CFG companion tracking + self._companion_map: dict[str, dict[str, str]] = {} + self._companion_to_parent: dict[str, str] = {} + self._companion_ids: set[str] = set() + self._companion_done: dict[str, set[str]] = {} + self._deferred_parents: dict[str, dict[str, Any]] = {} + + # Per-stage metrics accumulators. + self._batch_seq: list[int] = [0] * self.num_stages + self._agg_total_tokens: list[int] = [0] * self.num_stages + self._agg_total_gen_time_ms: list[float] = [0.0] * self.num_stages + + # Shutdown coordination + self._shutdown_event = asyncio.Event() + self._stages_shutdown = False + + async def run(self) -> None: + """Main entry point for the Orchestrator event loop.""" + logger.info("[Orchestrator] Starting event loop") + + request_task = asyncio.create_task(self._request_handler(), name="orchestrator-request-handler") + output_task = asyncio.create_task( + self._orchestration_output_handler(), + name="orchestrator-stage-output-handler", + ) + + try: + # Run both tasks concurrently; if either fails the other is cancelled. 
+ await asyncio.gather(request_task, output_task) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("[Orchestrator] Fatal error in orchestrator tasks") + raise + finally: + self._shutdown_event.set() + for t in (request_task, output_task): + if not t.done(): + t.cancel() + try: + await asyncio.gather(request_task, output_task, return_exceptions=True) + except Exception: + pass + + self._shutdown_stages() + + # Cancel any remaining tasks spawned by wait_for / gather so + # the event loop can close cleanly without "pending task" errors. + loop = asyncio.get_running_loop() + pending = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task() and not t.done()] + for t in pending: + t.cancel() + if pending: + await asyncio.gather(*pending, return_exceptions=True) + + async def _request_handler(self) -> None: + """Read messages from the main thread via request_async_queue.""" + while True: + msg = await self.request_async_queue.get() + msg_type = msg.get("type") + + if msg_type == "add_request": + await self._handle_add_request(msg) + elif msg_type == "add_companion_request": + await self._handle_add_companion(msg) + elif msg_type == "abort": + await self._handle_abort(msg) + elif msg_type == "collective_rpc": + await self._handle_collective_rpc(msg) + elif msg_type == "shutdown": + logger.info("[Orchestrator] Received shutdown signal") + self._shutdown_event.set() + self._shutdown_stages() + break + else: + logger.warning(f"[Orchestrator] Unknown message type: {msg_type}") + + async def _orchestration_output_handler(self) -> None: + """Poll all stages, handle transfers, send final outputs to main.""" + try: + await self._orchestration_loop() + except asyncio.CancelledError: + logger.debug("[Orchestrator] _orchestration_output_handler cancelled") + return + + async def _orchestration_loop(self) -> None: + """Inner loop for _orchestration_output_handler (clean cancellation). 

        Control flow: poll raw → process through output processor → route.
        """
        while not self._shutdown_event.is_set():
            # `idle` tracks whether ANY stage yielded output this sweep; when
            # all stages were idle we back off briefly to avoid a busy-spin.
            idle = True
            for stage_id in range(self.num_stages):
                if self._shutdown_event.is_set():
                    return

                # Fast path for diffusion stages: their outputs come from a
                # non-blocking queue and skip steps 1-2 below.
                # TODO (Peiqi): the output of diffusion stage is OmniRequestOutput,
                # which is different from EngineCoreOutputs (LLM stages). We may want to unify
                # the output format in the future to simplify the processing logic in Orchestrator.
                stage_client = self.stage_clients[stage_id]
                if stage_client.stage_type == "diffusion":
                    output = stage_client.get_diffusion_output_async()
                    if output is not None:
                        idle = False
                        req_state = self.request_states.get(output.request_id)
                        if req_state is not None:
                            stage_metrics = self._build_stage_metrics(stage_id, output.request_id, [output], req_state)
                            await self._route_output(stage_id, output, req_state, stage_metrics)
                    continue

                # 1) Poll raw outputs from the stage (short timeout keeps the
                #    sweep over all stages responsive).
                try:
                    raw_outputs = await asyncio.wait_for(self._poll_stage_raw(stage_id), timeout=0.001)
                except asyncio.TimeoutError:
                    continue
                except asyncio.CancelledError:
                    raise
                except Exception:
                    # During shutdown, poll failures are expected; otherwise fatal.
                    if self._shutdown_event.is_set():
                        return
                    logger.exception(
                        "[Orchestrator] _poll_stage_raw failed for stage-%s",
                        stage_id,
                    )
                    raise

                if raw_outputs is None:
                    continue
                idle = False

                # 2) Process raw outputs through the output processor
                request_outputs = await self._process_stage_outputs(stage_id, raw_outputs)

                # 3) Route each processed output
                for output in request_outputs:
                    req_state = self.request_states.get(output.request_id)
                    if req_state is None:
                        # Can happen for late outputs after abort/cleanup.
                        logger.warning(
                            "[Orchestrator] Dropping output for unknown req %s at stage-%s (known reqs: %s)",
                            output.request_id,
                            stage_id,
                            list(self.request_states.keys()),
                        )
                        continue
                    stage_metrics = None
                    # Metrics are only built once per stage, on completion.
                    if output.finished:
                        stage_metrics = self._build_stage_metrics(
                            stage_id,
                            output.request_id,
                            [output],
req_state, + ) + await self._route_output(stage_id, output, req_state, stage_metrics) + + if idle: + await asyncio.sleep(0.001) + else: + await asyncio.sleep(0) + + async def _route_output( + self, + stage_id: int, + output: Any, + req_state: OrchestratorRequestState, + stage_metrics: Any, + ) -> None: + """Route a processed output: send to main thread and/or forward to next stage.""" + req_id = output.request_id + finished = output.finished + submit_ts = req_state.stage_submit_ts.get(stage_id) + stage_client = self.stage_clients[stage_id] + + # CFG companion handling: companions don't produce user-visible output + # and don't forward to the next stage directly. + if finished and req_id in self._companion_ids: + parent_id = self._companion_to_parent.get(req_id) + if parent_id is not None: + self._companion_done.setdefault(parent_id, set()).add(req_id) + logger.debug( + "[Orchestrator] CFG companion %s done (parent=%s)", + req_id, + parent_id, + ) + # Check if parent is waiting and all companions are done + if parent_id in self._deferred_parents and self._all_companions_done(parent_id): + deferred = self._deferred_parents.pop(parent_id) + parent_state = self.request_states.get(parent_id) + if parent_state is not None: + await self._forward_to_next_stage( + parent_id, + deferred["stage_id"], + deferred["output"], + parent_state, + ) + self.request_states.pop(req_id, None) + return + + if stage_client.final_output: + await self.output_async_queue.put( + { + "type": "output", + "request_id": req_id, + "stage_id": stage_id, + "engine_outputs": output, + "metrics": stage_metrics, + "finished": finished and stage_id == req_state.final_stage_id, + "stage_submit_ts": submit_ts, + } + ) + elif stage_metrics is not None: + await self.output_async_queue.put( + { + "type": "stage_metrics", + "request_id": req_id, + "stage_id": stage_id, + "metrics": stage_metrics, + "stage_submit_ts": submit_ts, + } + ) + + if finished and stage_id < req_state.final_stage_id and not 
self.async_chunk: + # If this parent has CFG companions, defer forwarding until all done + if req_id in self._companion_map and not self._all_companions_done(req_id): + self._deferred_parents[req_id] = { + "stage_id": stage_id, + "output": output, + } + logger.debug( + "[Orchestrator] Parent %s deferred, waiting for CFG companions", + req_id, + ) + else: + await self._forward_to_next_stage(req_id, stage_id, output, req_state) + + if finished and stage_id == req_state.final_stage_id: + self._cleanup_companion_state(req_id) + self.request_states.pop(req_id, None) + + def _cleanup_companion_state(self, parent_id: str) -> None: + """Remove all companion tracking state for a completed parent.""" + role_map = self._companion_map.pop(parent_id, {}) + for cid in role_map.values(): + self._companion_ids.discard(cid) + self._companion_to_parent.pop(cid, None) + self._companion_done.pop(parent_id, None) + self._deferred_parents.pop(parent_id, None) + + def _all_companions_done(self, parent_id: str) -> bool: + """Check whether all CFG companions for a parent request have finished.""" + role_map = self._companion_map.get(parent_id, {}) + if not role_map: + return True + done_set = self._companion_done.get(parent_id, set()) + return all(cid in done_set for cid in role_map.values()) + + def _build_stage_metrics( + self, + stage_id: int, + req_id: str, + request_outputs: list[RequestOutput], + req_state: OrchestratorRequestState, + ) -> StageRequestMetrics: + """Build StageRequestMetrics for a finished request at a stage. + + Reuses StageRequestMetrics so OrchestratorMetrics and downstream + metric handlers can consume a stable schema. 
+ """ + now = _time.time() + submit_ts = req_state.stage_submit_ts.get(stage_id, now) + stage_gen_time_ms = (now - submit_ts) * 1000.0 + + num_tokens_out = count_tokens_from_outputs(request_outputs) + num_tokens_in = 0 + if stage_id == 0: + for ro in request_outputs: + ptids = getattr(ro, "prompt_token_ids", None) + if ptids is not None: + num_tokens_in += len(ptids) + + # Monotonic batch counter per stage. + self._batch_seq[stage_id] += 1 + batch_id = self._batch_seq[stage_id] + + # Accumulate for running-average stage_stats + self._agg_total_tokens[stage_id] += num_tokens_out + self._agg_total_gen_time_ms[stage_id] += stage_gen_time_ms + + return StageRequestMetrics( + num_tokens_in=num_tokens_in, + num_tokens_out=num_tokens_out, + stage_gen_time_ms=stage_gen_time_ms, + batch_id=batch_id, + batch_size=1, + rx_decode_time_ms=0.0, + rx_transfer_bytes=0, + rx_in_flight_time_ms=0.0, + stage_stats=StageStats( + total_token=self._agg_total_tokens[stage_id], + total_gen_time_ms=self._agg_total_gen_time_ms[stage_id], + ), + ) + + async def _forward_to_next_stage( + self, + req_id: str, + stage_id: int, + output: Any, + req_state: OrchestratorRequestState, + ) -> None: + """Forward output from current stage to the next stage. + + Handles the full pipeline: set outputs on current stage, compute + next-stage inputs, build lightweight requests, and submit them. 
+ """ + next_stage_id = stage_id + 1 + next_client = self.stage_clients[next_stage_id] + params = req_state.sampling_params_list[next_stage_id] + + if next_client.stage_type == "diffusion": + self.stage_clients[stage_id].set_engine_outputs([output]) + if next_client.custom_process_input_func is not None: + diffusion_prompt = next_client.custom_process_input_func( + self.stage_clients, + next_client.engine_input_source, + req_state.prompt, + False, + ) + if isinstance(diffusion_prompt, list): + diffusion_prompt = diffusion_prompt[0] + else: + diffusion_prompt = req_state.prompt + + # Attach CFG companion KV request IDs so the diffusion model + # runner can fetch companion KV caches alongside the primary one. + cfg_ids = self._companion_map.get(req_id) + if cfg_ids: + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + if isinstance(params, OmniDiffusionSamplingParams): + params = copy.deepcopy(params) + params.cfg_kv_request_ids = cfg_ids + logger.info( + "[Orchestrator] Attaching cfg_kv_request_ids=%s to req %s", + cfg_ids, + req_id, + ) + + await next_client.add_request_async(req_id, diffusion_prompt, params) + req_state.stage_submit_ts[next_stage_id] = _time.time() + return + + self.stage_clients[stage_id].set_engine_outputs([output]) + + # Process inputs for next stage + try: + next_inputs = next_client.process_engine_inputs( + stage_list=self.stage_clients, + prompt=req_state.prompt, + ) + except Exception: + logger.exception( + "[Orchestrator] req=%s process_engine_inputs FAILED for stage-%s", + req_id, + next_stage_id, + ) + raise + + # Build and submit requests for each input + for next_input in next_inputs: + request = build_engine_core_request_from_tokens( + request_id=req_id, + prompt=next_input, + params=params, + model_config=self.stage_vllm_configs[next_stage_id].model_config, + ) + + # TODO: Here we directly use the req id to assign. 
+ request.external_req_id = request.request_id + + self.output_processors[next_stage_id].add_request( + request=request, + prompt=None, + parent_req=None, + request_index=0, + queue=None, + ) + + await next_client.add_request_async(request) + + # Record submit timestamp for the next stage + req_state.stage_submit_ts[next_stage_id] = _time.time() + + async def _poll_stage_raw(self, stage_id: int) -> EngineCoreOutputs | None: + """Pull raw EngineCoreOutputs from a stage client without processing. + + Returns the raw outputs object, or None when there is nothing + to consume. + """ + outputs = await self.stage_clients[stage_id].get_output_async() + if not outputs.outputs: + return None + return outputs + + async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> list[RequestOutput]: + """Run the output processor on raw outputs, returning RequestOutputs. + + Also handles abort forwarding and scheduler stats updates. + """ + processor = self.output_processors[stage_id] + + processed = processor.process_outputs( + raw_outputs.outputs, + raw_outputs.timestamp, + None, + ) + + if processed.reqs_to_abort: + await self.stage_clients[stage_id].abort_requests_async(processed.reqs_to_abort) + + if raw_outputs.scheduler_stats is not None: + processor.update_scheduler_stats(raw_outputs.scheduler_stats) + + return processed.request_outputs + + async def _handle_add_request(self, msg: dict[str, Any]) -> None: + """Handle an add_request message from the main thread.""" + stage_id = 0 + request_id = msg["request_id"] + prompt = msg["prompt"] + original_prompt = msg.get("original_prompt", prompt) + sampling_params_list = msg["sampling_params_list"] + if not sampling_params_list: + raise ValueError(f"Missing sampling params for stage 0. 
Got {len(sampling_params_list)} stage params.") + params = sampling_params_list[0] + final_stage_id = msg["final_stage_id"] + + logger.info( + "[Orchestrator] _handle_add_request: stage=%s req=%s " + "prompt_type=%s original_prompt_type=%s final_stage=%s " + "num_sampling_params=%d", + stage_id, + request_id, + type(prompt).__name__, + type(original_prompt).__name__, + final_stage_id, + len(sampling_params_list), + ) + + # Track request state - use original_prompt so downstream stages + # (e.g. thinker2talker) can access the raw dict with multi_modal_data. + req_state = OrchestratorRequestState( + request_id=request_id, + prompt=original_prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id, + ) + req_state.stage_submit_ts[stage_id] = _time.time() + self.request_states[request_id] = req_state + + # Stage-0 prompt is already a fully-formed OmniEngineCoreRequest + # (pre-processed by AsyncOmniEngine.add_request, output processor + # already registered there) - submit directly. + request = prompt + stage_client = self.stage_clients[stage_id] + if stage_client.stage_type == "diffusion": + await stage_client.add_request_async(request_id, prompt, params) + else: + await stage_client.add_request_async(request) + + if self.async_chunk and stage_id == 0 and final_stage_id > 0: + await self._prewarm_async_chunk_stages(request_id, request, req_state) + + async def _prewarm_async_chunk_stages( + self, + request_id: str, + stage0_request: Any, + req_state: OrchestratorRequestState, + ) -> None: + """Pre-submit downstream stages for async-chunk mode. + + In async-chunk mode, stages exchange data through connectors/chunk adapters, + so downstream stages should be armed once at request start instead of waiting + for stage-finished forwarding. 
+ """ + if req_state.final_stage_id <= 0: + return + + prompt_token_ids = getattr(stage0_request, "prompt_token_ids", None) + if prompt_token_ids is None: + logger.warning( + "[Orchestrator] async_chunk prewarm skipped for req=%s: stage0 prompt_token_ids missing", + request_id, + ) + return + + # Pre-arm stage-1+ with placeholder prompt IDs. + try: + next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids)) + except Exception: + next_prompt_len = max(1, len(prompt_token_ids)) + original_prompt = req_state.prompt + if isinstance(original_prompt, dict): + base_input = copy.deepcopy(original_prompt) + else: + base_input = {} + base_input["prompt_token_ids"] = [0] * next_prompt_len + base_input["multi_modal_data"] = None + base_input["mm_processor_kwargs"] = None + + for next_stage_id in range(1, req_state.final_stage_id + 1): + next_client = self.stage_clients[next_stage_id] + params = req_state.sampling_params_list[next_stage_id] + + if next_client.stage_type == "diffusion": + await next_client.add_request_async(request_id, req_state.prompt, params) + req_state.stage_submit_ts[next_stage_id] = _time.time() + continue + + request = build_engine_core_request_from_tokens( + request_id=request_id, + prompt=base_input, + params=params, + model_config=self.stage_vllm_configs[next_stage_id].model_config, + ) + request.external_req_id = request.request_id + + self.output_processors[next_stage_id].add_request( + request=request, + prompt=None, + parent_req=None, + request_index=0, + queue=None, + ) + await next_client.add_request_async(request) + req_state.stage_submit_ts[next_stage_id] = _time.time() + + async def _handle_add_companion(self, msg: dict[str, Any]) -> None: + """Handle an add_companion_request message: submit companion to stage 0.""" + companion_id = msg["companion_id"] + parent_id = msg["parent_id"] + role = msg["role"] + companion_prompt = msg["prompt"] + sampling_params_list = msg["sampling_params_list"] + + # Register companion mapping + 
if parent_id not in self._companion_map: + self._companion_map[parent_id] = {} + self._companion_map[parent_id][role] = companion_id + self._companion_ids.add(companion_id) + self._companion_to_parent[companion_id] = parent_id + self._companion_done.setdefault(parent_id, set()) + + companion_state = OrchestratorRequestState( + request_id=companion_id, + prompt=companion_prompt, + sampling_params_list=sampling_params_list, + final_stage_id=0, + ) + companion_state.stage_submit_ts[0] = _time.time() + self.request_states[companion_id] = companion_state + + request = companion_prompt # Already a processed OmniEngineCoreRequest + stage_client = self.stage_clients[0] + await stage_client.add_request_async(request) + + logger.info( + "[Orchestrator] CFG companion submitted: %s (role=%s, parent=%s)", + companion_id, + role, + parent_id, + ) + + async def _handle_abort(self, msg: dict[str, Any]) -> None: + """Handle an abort message from the main thread.""" + request_ids = msg["request_ids"] + # Also abort any CFG companions for aborted parents + companion_ids_to_abort: list[str] = [] + for req_id in request_ids: + role_map = self._companion_map.pop(req_id, {}) + for cid in role_map.values(): + companion_ids_to_abort.append(cid) + self._companion_ids.discard(cid) + self._companion_to_parent.pop(cid, None) + self.request_states.pop(cid, None) + self._companion_done.pop(req_id, None) + self._deferred_parents.pop(req_id, None) + + all_ids_to_abort = list(request_ids) + companion_ids_to_abort + for stage_id in range(self.num_stages): + await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort) + for req_id in request_ids: + self.request_states.pop(req_id, None) + logger.info("[Orchestrator] Aborted request(s) %s", request_ids) + + async def _handle_collective_rpc(self, msg: dict[str, Any]) -> None: + """Handle a control-plane RPC request from the main thread. + + TODO(AsyncOmni): parallelize stage dispatch if control latency becomes + noticeable. 
The current sequential fanout keeps the first version simple
        and deterministic.
        """
        rpc_id = msg["rpc_id"]
        method = msg["method"]
        timeout = msg.get("timeout")
        args = tuple(msg.get("args", ()))
        kwargs = dict(msg.get("kwargs") or {})
        # None means "all stages"; otherwise honor the caller's subset.
        requested_stage_ids = msg.get("stage_ids")
        stage_ids = list(range(self.num_stages)) if requested_stage_ids is None else list(requested_stage_ids)

        # One result dict per requested stage, in the same order as stage_ids.
        results: list[Any] = []
        for stage_id in stage_ids:
            # Invalid ids produce an error entry instead of raising, so one
            # bad id does not abort the whole fanout.
            if stage_id < 0 or stage_id >= self.num_stages:
                results.append(
                    {
                        "supported": False,
                        "todo": True,
                        "error": f"Invalid stage id {stage_id}",
                    }
                )
                continue

            stage_client = self.stage_clients[stage_id]
            try:
                if hasattr(stage_client, "collective_rpc_async"):
                    stage_result = await stage_client.collective_rpc_async(
                        method=method,
                        timeout=timeout,
                        args=args,
                        kwargs=kwargs,
                    )
                else:
                    # Client type (e.g. diffusion) has not implemented RPC yet.
                    stage_result = {
                        "supported": False,
                        "todo": True,
                        "reason": (f"{stage_client.__class__.__name__}.collective_rpc_async is not implemented yet"),
                    }
            except Exception as exc:
                # Isolate per-stage failures: log and report, keep fanning out.
                logger.exception(
                    "[Orchestrator] collective_rpc failed: stage=%s method=%s",
                    stage_id,
                    method,
                )
                stage_result = {
                    "supported": False,
                    "error": str(exc),
                }

            results.append(stage_result)

        # Reply to the main thread via the dedicated RPC queue.
        await self.rpc_async_queue.put(
            {
                "type": "collective_rpc_result",
                "rpc_id": rpc_id,
                "method": method,
                "stage_ids": stage_ids,
                "results": results,
            }
        )

    def _shutdown_stages(self) -> None:
        """Shutdown all stage clients. Idempotent: only the first call acts."""
        if self._stages_shutdown:
            return

        self._stages_shutdown = True
        logger.info("[Orchestrator] Shutting down all stages")
        for stage_id, stage_client in enumerate(self.stage_clients):
            # Best-effort: a failing stage must not block shutdown of the rest.
            try:
                stage_client.shutdown()
                logger.info(f"[Orchestrator] Stage {stage_id} shut down")
            except Exception as e:
                logger.warning(f"[Orchestrator] Failed to shutdown stage {stage_id}: {e}")
diff --git a/vllm_omni/engine/serialization.py b/vllm_omni/engine/serialization.py
new file mode 100644
index 0000000000..41146a01b6
--- /dev/null
+++ b/vllm_omni/engine/serialization.py
@@ -0,0 +1,81 @@
"""Shared serialization helpers for omni engine request payloads."""

from __future__ import annotations

from typing import Any

import torch
from vllm.logger import init_logger

from vllm_omni.engine import (
    AdditionalInformationEntry,
    AdditionalInformationPayload,
)

logger = init_logger(__name__)


def dtype_to_name(dtype: torch.dtype) -> str:
    """Convert torch dtype to a stable string name for serialization.

    Unknown dtypes fall back to ``str(dtype)`` with the ``torch.`` prefix
    stripped. Alias dtypes (e.g. ``torch.float`` is the same object as
    ``torch.float32``) collapse to a single dict key; the alias entries are
    kept only for readability.
    """
    mapping = {
        torch.float32: "float32",
        torch.float: "float32",
        torch.float16: "float16",
        torch.half: "float16",
        torch.bfloat16: "bfloat16",
        torch.float64: "float64",
        torch.double: "float64",
        torch.int64: "int64",
        torch.long: "int64",
        torch.int32: "int32",
        torch.int: "int32",
        torch.int16: "int16",
        torch.short: "int16",
        torch.int8: "int8",
        torch.uint8: "uint8",
        torch.bool: "bool",
    }
    return mapping.get(dtype, str(dtype).replace("torch.", ""))


def serialize_additional_information(
    raw_info: dict[str, Any] | AdditionalInformationPayload | None,
    *,
    log_prefix: str | None = None,
) -> AdditionalInformationPayload | None:
    """Serialize omni request metadata for EngineCore transport.

    Tensors are detached, moved to CPU, made contiguous, and stored as raw
    bytes plus shape/dtype so the receiver can reconstruct them. Lists are
    passed through as-is (assumes plainly serializable elements — TODO
    confirm against the transport layer). Any other value type is dropped
    with a warning. Returns None when nothing serializable remains; an
    already-built payload is returned unchanged.
    """
    if raw_info is None:
        return None
    if isinstance(raw_info, AdditionalInformationPayload):
        return raw_info

    entries: dict[str, AdditionalInformationEntry] = {}
    for key, value in raw_info.items():
        if isinstance(value, torch.Tensor):
            value_cpu = value.detach().to("cpu").contiguous()
            entries[key] = AdditionalInformationEntry(
                tensor_data=value_cpu.numpy().tobytes(),
                tensor_shape=list(value_cpu.shape),
                tensor_dtype=dtype_to_name(value_cpu.dtype),
            )
            continue

        if isinstance(value, list):
            entries[key] = AdditionalInformationEntry(list_data=value)
            continue

        # Unsupported type: warn (with optional caller prefix) and drop.
        if log_prefix is None:
            logger.warning(
                "Dropping unsupported additional_information key=%s type=%s",
                key,
type(value).__name__, + ) + else: + logger.warning( + "[%s] Dropping unsupported additional_information key=%s type=%s", + log_prefix, + key, + type(value).__name__, + ) + + return AdditionalInformationPayload(entries=entries) if entries else None diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py new file mode 100644 index 0000000000..a4e8bc0e93 --- /dev/null +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -0,0 +1,168 @@ +""" +Stage Engine Core Client for vLLM-Omni multi-stage runtime. + +Directly inherits from vLLM's AsyncMPClient to reuse EngineCore architecture. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.core_client import AsyncMPClient + +from vllm_omni.engine.stage_init_utils import StageMetadata + +if TYPE_CHECKING: + from vllm.v1.engine import EngineCoreOutput + + from vllm_omni.inputs.data import OmniTokensPrompt + +logger = init_logger(__name__) + + +class StageEngineCoreClient(AsyncMPClient): + """Stage async client that inherits from vLLM's AsyncMPClient. + + Fully reuses AsyncMPClient.__init__ for: + - ZMQ setup, sockets + - launch_core_engines() -> EngineCoreProc + - outputs_queue, output_queue_task + - All utility methods (shutdown, get_output_async, abort_requests_async, etc.) + + This is the async version of StageMPClient, designed for use with AsyncOmniEngine. + """ + + def __init__( + self, + vllm_config: Any, + executor_class: type, + metadata: StageMetadata, + client_addresses: dict[str, str] | None = None, + engine_manager: Any = None, + coordinator: Any = None, + ): + """Create an async EngineCore client for a single stage. + + All heavy init (config extraction, plugin loading, device setup, + engine args building, device locking) is done by the Orchestrator + via helpers in stage_init_utils.py. 
This constructor just stores metadata + and calls super().__init__(). + """ + # -------- Stage metadata (public fields used at runtime) -------- + self.stage_id = metadata.stage_id + self.stage_type = metadata.stage_type + self.engine_output_type = metadata.engine_output_type + self.is_comprehension = metadata.is_comprehension + self.requires_multimodal_data = metadata.requires_multimodal_data + self.engine_input_source = metadata.engine_input_source + self.final_output = metadata.final_output + self.final_output_type = metadata.final_output_type + self.default_sampling_params = metadata.default_sampling_params + self.custom_process_input_func = metadata.custom_process_input_func + self.model_stage = metadata.model_stage + + self.engine_outputs: Any = None + + logger.info( + "[StageEngineCoreClient] Stage-%s initializing EngineCore", + self.stage_id, + ) + try: + super().__init__( + vllm_config, + executor_class, + log_stats=False, + client_addresses=client_addresses, + ) + if engine_manager is not None: + self.resources.engine_manager = engine_manager + if coordinator is not None: + self.resources.coordinator = coordinator + except Exception: + logger.exception( + "[StageEngineCoreClient] Stage-%s EngineCore init failed", + self.stage_id, + ) + try: + self.shutdown() + except Exception as shutdown_error: + logger.warning( + "[StageEngineCoreClient] Stage-%s cleanup after init failure failed: %s", + self.stage_id, + shutdown_error, + ) + raise + logger.info( + "[StageEngineCoreClient] Stage-%s EngineCore running", + self.stage_id, + ) + + # ==================== Overrides ==================== + + async def add_request_async(self, request: EngineCoreRequest) -> None: + """Add request to the stage engine core.""" + logger.info(f"[StageEngineCoreClient] Stage-{self.stage_id} adding request: {request.request_id}") + await super().add_request_async(request) + + # ==================== Stage Methods ==================== + + def set_engine_outputs(self, engine_outputs: 
EngineCoreOutput) -> None: + """Set engine outputs (called by orchestrator).""" + self.engine_outputs = engine_outputs + + def process_engine_inputs( + self, + stage_list: list[Any], + prompt: OmniTokensPrompt | list[OmniTokensPrompt] | None = None, + ) -> list[OmniTokensPrompt]: + """Process inputs from upstream stages.""" + from vllm_omni.inputs.data import OmniTokensPrompt + + if self.custom_process_input_func is not None: + return self.custom_process_input_func( + stage_list, + self.engine_input_source, + prompt, + self.requires_multimodal_data, + ) + + if not self.engine_input_source: + raise ValueError(f"engine_input_source empty for stage {self.stage_id}") + + source_id = self.engine_input_source[0] + source_outputs = stage_list[source_id].engine_outputs + + if not isinstance(prompt, list): + prompt = [prompt] + + mm_data = {so.request_id: p.get("multi_modal_data") for so, p in zip(source_outputs, prompt)} + + return [ + OmniTokensPrompt( + prompt_token_ids=so.outputs[0].token_ids, + multi_modal_data=(mm_data[so.request_id] if self.requires_multimodal_data else None), + ) + for so in source_outputs + ] + + async def collective_rpc_async( + self, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + """Forward control RPCs to the underlying AsyncMPClient stage engine. + + Each ``StageEngineCoreClient`` already represents one logical stage, so + stage-scoped control operations should be executed here and then fanned + in-core across the workers managed by this EngineCore client. + """ + return await super().collective_rpc_async( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, + ) diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py new file mode 100644 index 0000000000..a15eb08c8e --- /dev/null +++ b/vllm_omni/engine/stage_init_utils.py @@ -0,0 +1,525 @@ +""" +Stage initialization helpers for vLLM-Omni multi-stage runtime. 
+ +Extracts orchestration-level init logic (config extraction, plugin loading, +multiprocessing setup, device mapping, device locking, engine args building) +out of StageEngineCoreClient into reusable functions. +""" + +from __future__ import annotations + +import fcntl +import importlib +import multiprocessing as mp +import os +import time +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal + +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine.input_processor import InputProcessor +from vllm.v1.executor import Executor + +from vllm_omni.engine.arg_utils import OmniEngineArgs +from vllm_omni.entrypoints.stage_utils import _to_dict, set_stage_devices +from vllm_omni.entrypoints.utils import filter_dataclass_kwargs, resolve_model_config_path +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams + +logger = init_logger(__name__) + + +def _resolve_model_to_local_path(model: str) -> str: + """Resolve an HF Hub model ID to a local cache path.""" + if os.path.isdir(model): + return model + + try: + from huggingface_hub import snapshot_download + + # Keep init path resolution offline-friendly. 
        return snapshot_download(model, local_files_only=True)
    except Exception:
        # Not cached locally (or hub unavailable): fall back to the raw id.
        logger.warning(
            "[stage_init] Could not resolve %s to local snapshot; using as-is",
            model,
        )
        return model


def _resolve_model_tokenizer_paths(model: str, engine_args: dict[str, Any]) -> str:
    """Apply model_subdir/tokenizer_subdir indirections from stage engine args.

    NOTE: mutates ``engine_args`` in place — ``model_subdir`` and
    ``tokenizer_subdir`` are pop()ed, and ``tokenizer`` may be set.
    Returns the (possibly subdirectory-joined) model path.
    """
    model_subdir = engine_args.pop("model_subdir", None)
    tokenizer_subdir = engine_args.pop("tokenizer_subdir", None)
    # Fast path: no indirection requested, model path is used verbatim.
    if model_subdir is None and tokenizer_subdir is None:
        return model

    resolved_base = _resolve_model_to_local_path(model)

    if model_subdir:
        model = os.path.join(resolved_base, model_subdir)
        logger.info("[stage_init] Using model subdirectory: %s", model)

    if tokenizer_subdir is not None:
        # An explicit empty string means "tokenizer lives at the base path".
        tokenizer_path = os.path.join(resolved_base, tokenizer_subdir) if tokenizer_subdir else resolved_base
        engine_args["tokenizer"] = tokenizer_path
        logger.info("[stage_init] Using tokenizer from: %s", tokenizer_path)
    elif model_subdir and "tokenizer" not in engine_args:
        # Keep legacy behavior: model in subdir, tokenizer defaults to base path.
+ engine_args["tokenizer"] = resolved_base + logger.info("[stage_init] Using tokenizer from base model path: %s", resolved_base) + + return model + + +def resolve_worker_cls(engine_args: dict[str, Any]) -> None: + """Resolve worker_cls from worker_type for non-diffusion stages.""" + worker_type = engine_args.get("worker_type", None) + if not worker_type: + return + worker_cls = engine_args.get("worker_cls") + if worker_cls is not None and worker_cls != "auto": + return + + from vllm_omni.platforms import current_omni_platform + + worker_type = str(worker_type).lower() + if worker_type == "ar": + engine_args["worker_cls"] = current_omni_platform.get_omni_ar_worker_cls() + elif worker_type == "generation": + engine_args["worker_cls"] = current_omni_platform.get_omni_generation_worker_cls() + else: + raise ValueError(f"Unknown worker_type: {worker_type}") + + +@dataclass +class StageMetadata: + """Lightweight stage attributes extracted from stage_config.""" + + stage_id: int + stage_type: Literal["llm", "diffusion"] + engine_output_type: str | None + is_comprehension: bool + requires_multimodal_data: bool + engine_input_source: list[int] + final_output: bool + final_output_type: str | None + default_sampling_params: OmniSamplingParams + custom_process_input_func: Callable | None + model_stage: str | None + runtime_cfg: Any + prompt_expand_func: Callable | None = None + cfg_kv_collect_func: Callable | None = None + + +@dataclass +class StartedLlmStage: + """Resources for an LLM stage that has completed startup.""" + + stage_id: int + metadata: Any + vllm_config: Any + executor_class: type + engine_manager: Any + coordinator: Any + addresses: Any + + +def extract_stage_metadata(stage_config: Any) -> StageMetadata: + """Pure data extraction from a stage_config object.""" + stage_id: int = stage_config.stage_id + stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm") + engine_args = stage_config.engine_args + runtime_cfg = 
getattr(stage_config, "runtime", {}) + engine_input_source: list[int] = getattr(stage_config, "engine_input_source", []) + final_output: bool = getattr(stage_config, "final_output", False) + final_output_type: str | None = getattr(stage_config, "final_output_type", None) + + default_sp = _to_dict(getattr(stage_config, "default_sampling_params", {})) + SPClass = SamplingParams if stage_type == "llm" else OmniDiffusionSamplingParams + default_sampling_params: OmniSamplingParams = SPClass(**default_sp) + + custom_process_input_func: Callable | None = None + if hasattr(stage_config, "custom_process_input_func"): + mod_path, fn_name = stage_config.custom_process_input_func.rsplit(".", 1) + custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name) + + prompt_expand_func: Callable | None = None + _pef_path = getattr(stage_config, "prompt_expand_func", None) + if _pef_path: + _mod, _fn = _pef_path.rsplit(".", 1) + prompt_expand_func = getattr(importlib.import_module(_mod), _fn) + + cfg_kv_collect_func: Callable | None = None + _ckf_path = getattr(stage_config, "cfg_kv_collect_func", None) + if _ckf_path: + _mod, _fn = _ckf_path.rsplit(".", 1) + cfg_kv_collect_func = getattr(importlib.import_module(_mod), _fn) + + if stage_type == "diffusion": + return StageMetadata( + stage_id=stage_id, + stage_type="diffusion", + engine_output_type=None, + is_comprehension=False, + requires_multimodal_data=False, + engine_input_source=engine_input_source, + final_output=final_output, + final_output_type=final_output_type, + default_sampling_params=default_sampling_params, + custom_process_input_func=custom_process_input_func, + model_stage=None, + runtime_cfg=runtime_cfg, + cfg_kv_collect_func=cfg_kv_collect_func, + ) + + model_stage = getattr(engine_args, "model_stage", None) + engine_output_type = getattr(engine_args, "engine_output_type", None) + is_comprehension = getattr(stage_config, "is_comprehension", False) + requires_multimodal_data = getattr(runtime_cfg, 
"requires_multimodal_data", False) + + return StageMetadata( + stage_id=stage_id, + stage_type=stage_type, + engine_output_type=engine_output_type, + is_comprehension=is_comprehension, + requires_multimodal_data=requires_multimodal_data, + engine_input_source=engine_input_source, + final_output=final_output, + final_output_type=final_output_type, + default_sampling_params=default_sampling_params, + custom_process_input_func=custom_process_input_func, + model_stage=model_stage, + runtime_cfg=runtime_cfg, + prompt_expand_func=prompt_expand_func, + ) + + +def prepare_engine_environment() -> None: + """One-time global setup: load plugins, set multiprocessing spawn method.""" + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.info("[stage_init] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + + +def setup_stage_devices(stage_id: int, runtime_cfg: Any) -> None: + """Device mapping via set_stage_devices for a single stage.""" + try: + from vllm_omni.platforms import current_omni_platform + + device_type = current_omni_platform.device_type + set_stage_devices( + stage_id, + runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else None, + device_type=device_type, + ) + logger.info( + "[stage_init] Stage-%s set devices for %s, runtime devices: %s", + stage_id, + device_type, + runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else None, + ) + except Exception as e: + logger.warning("Device setup failed for stage %s: %s", stage_id, e) + + +def build_engine_args_dict( + stage_config: Any, + model: str, + stage_connector_spec: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build the normalized engine args dict for one stage.""" + engine_args = stage_config.engine_args + stage_type = getattr(stage_config, 
"stage_type", "llm") + stage_id = stage_config.stage_id + + engine_args_dict = _to_dict(engine_args) + model = _resolve_model_tokenizer_paths(model, engine_args_dict) + engine_args_dict["model"] = model + # Stage id must come from stage config instead of inherited CLI kwargs + # (e.g. `--stage-id` defaulting to None). + engine_args_dict["stage_id"] = stage_id + if engine_args_dict.get("async_chunk", False): + engine_args_dict["stage_connector_spec"] = dict(stage_connector_spec or {}) + + if stage_type != "diffusion": + resolve_worker_cls(engine_args_dict) + + return engine_args_dict + + +def build_vllm_config( + stage_config: Any, + model: str, + stage_connector_spec: dict[str, Any] | None = None, + engine_args_dict: dict[str, Any] | None = None, +) -> tuple[Any, type]: + """Build engine args, then create VllmConfig and executor_class. + + Returns: + (vllm_config, executor_class) + """ + if engine_args_dict is None: + engine_args_dict = build_engine_args_dict( + stage_config, + model, + stage_connector_spec=stage_connector_spec, + ) + + filtered_engine_args_dict = filter_dataclass_kwargs(OmniEngineArgs, engine_args_dict) + omni_engine_args = OmniEngineArgs(**filtered_engine_args_dict) + vllm_config = omni_engine_args.create_engine_config(usage_context=UsageContext.LLM_CLASS) + executor_class = Executor.get_class(vllm_config) + + return vllm_config, executor_class + + +def acquire_device_locks( + stage_id: int, + engine_args_dict: dict[str, Any], + stage_init_timeout: int = 300, +) -> list[int]: + """Acquire exclusive file locks on devices needed by this stage. + + Returns list of lock file descriptors that must be released after init. 
+ """ + lock_fds: list[int] = [] + try: + from vllm_omni.platforms import current_omni_platform + + # Get parallel sizes + if "parallel_config" in engine_args_dict: + pc = engine_args_dict["parallel_config"] + tensor_parallel_size = pc.get("tensor_parallel_size", 1) + pipeline_parallel_size = pc.get("pipeline_parallel_size", 1) + data_parallel_size = pc.get("data_parallel_size", 1) + prefill_context_parallel_size = pc.get("prefill_context_parallel_size", 1) + sequence_parallel_size = pc.get("sequence_parallel_size", 1) + cfg_parallel_size = pc.get("cfg_parallel_size", 1) + else: + tensor_parallel_size = engine_args_dict.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args_dict.get("pipeline_parallel_size", 1) + data_parallel_size = engine_args_dict.get("data_parallel_size", 1) + prefill_context_parallel_size = engine_args_dict.get("prefill_context_parallel_size", 1) + sequence_parallel_size = 1 + cfg_parallel_size = 1 + + num_devices_per_stage = ( + tensor_parallel_size + * pipeline_parallel_size + * data_parallel_size + * prefill_context_parallel_size + * sequence_parallel_size + * cfg_parallel_size + ) + + # Get physical device IDs + device_control_env = current_omni_platform.device_control_env_var + visible_devices_str = os.environ.get(device_control_env) + physical_devices: list[int] = [] + + if visible_devices_str: + try: + physical_devices = [int(x.strip()) for x in visible_devices_str.split(",") if x.strip()] + except (ValueError, IndexError): + pass + + if not physical_devices: + num_devices = current_omni_platform.get_device_count() + physical_devices = list(range(num_devices)) + + num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) + devices_to_lock = sorted(physical_devices[:num_devices_to_lock]) + + logger.debug( + "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d, CFG=%d; will lock %d devices: %s", + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + prefill_context_parallel_size, + 
sequence_parallel_size, + cfg_parallel_size, + num_devices_to_lock, + devices_to_lock, + ) + + # Acquire locks + wait_start = time.time() + for device_id in devices_to_lock: + lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" + lock_acquired = False + + while not lock_acquired: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_RDWR, 0o644) + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + os.ftruncate(lock_fd, 0) + os.write(lock_fd, f"{os.getpid()}\n".encode()) + os.fsync(lock_fd) + lock_acquired = True + lock_fds.append(lock_fd) + logger.debug("Acquired exclusive lock for device %s", device_id) + except BlockingIOError: + os.close(lock_fd) + if time.time() - wait_start > stage_init_timeout: + logger.warning( + "Timeout waiting for device %s initialization lock, proceeding anyway", + device_id, + ) + break + time.sleep(0.01) + except OSError as e: + logger.debug( + "Failed to acquire lock for device %s: %s, continuing anyway", + device_id, + e, + ) + try: + os.close(lock_fd) + except (OSError, NameError): + pass + break + + except Exception as e: + logger.debug( + "[Stage-%s] Failed to set up sequential initialization lock: %s", + stage_id, + e, + ) + + return lock_fds + + +def release_device_locks(lock_fds: list[int]) -> None: + """Release file locks acquired by acquire_device_locks.""" + for lock_fd in lock_fds: + try: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + os.close(lock_fd) + logger.debug("Released initialization lock (fd=%s)", lock_fd) + except (OSError, ValueError): + pass + + +def load_omni_transfer_config_for_model(model: str, config_path: str | None) -> Any: + """Load omni transfer config from an explicit path or resolved model config.""" + from vllm_omni.distributed.omni_connectors import load_omni_transfer_config + + try: + resolved_config_path = config_path or resolve_model_config_path(model) + return load_omni_transfer_config(resolved_config_path) + except Exception as e: + logger.warning("[stage_init] Failed to load 
transfer config: %s", e) + return None + + +def get_stage_connector_spec( + omni_transfer_config: Any, + stage_id: int, + async_chunk: bool, +) -> dict[str, Any]: + """Return the first connector spec for the stage when async chunking is enabled.""" + from vllm_omni.distributed.omni_connectors import get_stage_connector_config + + if not async_chunk: + return {} + + stage_connectors_cfg = get_stage_connector_config(omni_transfer_config, stage_id) + for cfg in stage_connectors_cfg.values(): + return dict(cfg.get("spec", {})) + return {} + + +def initialize_diffusion_stage(model: str, stage_cfg: Any, metadata: StageMetadata) -> Any: + """Build a diffusion stage client.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient + + od_config = OmniDiffusionConfig.from_kwargs( + model=model, + **_to_dict(stage_cfg.engine_args), + ) + if metadata.cfg_kv_collect_func is not None: + od_config.cfg_kv_collect_func = metadata.cfg_kv_collect_func + return StageDiffusionClient(model, od_config, metadata) + + +def close_started_llm_stage(started: StartedLlmStage) -> None: + """Close managers owned by a launched stage that never attached.""" + resources = ( + ("engine manager", started.engine_manager), + ("coordinator", started.coordinator), + ) + for resource_name, resource in resources: + if resource is None: + continue + try: + resource.close() + except Exception as cleanup_error: + logger.warning( + "[stage_init] Failed to close launched %s for stage %s: %s", + resource_name, + started.stage_id, + cleanup_error, + ) + + +def finalize_initialized_stages( + stage_clients: list[Any | None], + input_processor: InputProcessor | None, +) -> tuple[list[Any], list[Any], list[dict[str, Any]]]: + """Validate successful init and build runtime metadata lists.""" + if any(stage_client is None for stage_client in stage_clients): + raise RuntimeError("Stage initialization completed with missing stage clients") 
+ + initialized_stage_clients = [stage_client for stage_client in stage_clients if stage_client is not None] + default_sampling_params_list = [stage_client.default_sampling_params for stage_client in initialized_stage_clients] + stage_metadata = [ + { + "final_output": stage_client.final_output, + "final_output_type": stage_client.final_output_type, + "stage_type": stage_client.stage_type, + } + for stage_client in initialized_stage_clients + ] + + if not isinstance(input_processor, InputProcessor): + has_llm_stage = any(metadata.get("stage_type") != "diffusion" for metadata in stage_metadata) + if has_llm_stage: + raise RuntimeError("Failed to initialize stage-0 InputProcessor for LLM pipeline") + + return initialized_stage_clients, default_sampling_params_list, stage_metadata + + +def cleanup_failed_stage_initialization( + stage_clients: list[Any | None], + started_llm_stages: list[StartedLlmStage], +) -> None: + """Shutdown attached stages and close any launched-but-unattached engines.""" + for cleanup_stage_id, stage_client in reversed(list(enumerate(stage_clients))): + if stage_client is None: + continue + try: + stage_client.shutdown() + except Exception as cleanup_error: + logger.warning( + "[stage_init] Failed to shutdown initialized stage %s after init failure: %s", + cleanup_stage_id, + cleanup_error, + ) + + for started in reversed(started_llm_stages): + if stage_clients[started.stage_id] is not None: + continue + close_started_llm_stage(started) diff --git a/vllm_omni/entrypoints/__init__.py b/vllm_omni/entrypoints/__init__.py index 8d0ee51a51..d0830df96d 100644 --- a/vllm_omni/entrypoints/__init__.py +++ b/vllm_omni/entrypoints/__init__.py @@ -3,11 +3,6 @@ """ vLLM-Omni entrypoints module. 
- -Provides high-level interfaces for running omni models including: -- AsyncOmni: Async orchestrator for multi-stage LLM pipelines -- AsyncOmniDiffusion: Async interface for diffusion model inference -- Omni: Unified entrypoint that auto-selects between LLM and Diffusion """ from vllm_omni.entrypoints.async_omni import AsyncOmni diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 474834d16a..3ba2c4a7ef 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -1,108 +1,64 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +AsyncOmni - Refactored async orchestrator using AsyncOmniEngine. + +This is the new implementation that uses AsyncOmniEngine (which manages +StageEngineCoreClient instances) instead of OmniStage with worker processes. +""" + +from __future__ import annotations + import asyncio -import copy import time -import weakref -from collections.abc import AsyncGenerator, Callable, Iterable, Sequence -from typing import Any, TypeVar +from collections.abc import AsyncGenerator, Iterable, Sequence +from typing import TYPE_CHECKING, Any -from vllm.config import VllmConfig -from vllm.inputs.preprocess import InputPreprocessor +from vllm.engine.protocol import EngineClient from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import PoolingRequestOutput from vllm.plugins.io_processors import get_io_processor -from vllm.sampling_params import SamplingParams -from vllm.tokenizers import TokenizerLike +from vllm.pooling_params import PoolingParams +from vllm.tasks import SupportedTask from vllm.v1.engine.exceptions import EngineDeadError -from vllm_omni.config import OmniModelConfig -from vllm_omni.diffusion.data import DiffusionParallelConfig -from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length, try_send_via_connector -from 
vllm_omni.distributed.ray_utils.utils import try_close_ray -from vllm_omni.engine.input_processor import OmniInputProcessor -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.entrypoints.client_request_state import ClientRequestState -from vllm_omni.entrypoints.omni import OmniBase -from vllm_omni.entrypoints.omni_stage import OmniStage -from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType -from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load -from vllm_omni.entrypoints.utils import ( - get_final_stage_id_for_e2e, -) -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams - -# Internal imports (our code) -from vllm_omni.lora.request import LoRARequest -from vllm_omni.metrics import OrchestratorAggregator +from vllm_omni.entrypoints.omni_base import OmniBase +from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput -_R = TypeVar("_R") - -logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.inputs.preprocess import InputPreprocessor + from vllm.tokenizers import TokenizerLike + from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams -def _weak_close_cleanup_async( - stage_list, stage_in_queues, stage_out_queues, ray_pg, output_handler, zmq_ctx=None, inline_engine=None -): - """Weak reference cleanup function for AsyncOmni instances.""" - if inline_engine is not None: - try: - inline_engine.close() - except Exception as e: - logger.warning("Failed to close inline diffusion engine: %s", e) - if stage_list: - for q in stage_in_queues: - try: - q.put_nowait(SHUTDOWN_TASK) - except Exception as e: - logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for q in stage_out_queues: - close_fn = getattr(q, "close", None) - if callable(close_fn): - 
close_fn() - for stage in stage_list: - try: - stage.stop_stage_worker() - except Exception as e: - logger.warning(f"Failed to stop stage worker: {e}") - try_close_ray(ray_pg) - # Cancel output handler - if output_handler is not None: - output_handler.cancel() - if zmq_ctx is not None: - zmq_ctx.term() +logger = init_logger(__name__) +_FINAL_OUTPUT_IDLE_SLEEP_S = 0.001 -class AsyncOmni(OmniBase): - """Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models. +class AsyncOmni(EngineClient, OmniBase): + """Asynchronous unified entry point for multi-stage pipelines using AsyncOmniEngine. - Similar to the Omni class, but provides an asynchronous interface supporting - asynchronous LLM and Diffusion models. + This is the refactored version that uses AsyncOmniEngine instead of + OmniStage workers. It provides the same interface as AsyncOmni but with + a cleaner architecture. Args: model: Model name or path to load. - **kwargs: Arbitrary keyword arguments. + **kwargs: Additional keyword arguments. - stage_configs_path: Optional path to YAML file containing stage - configurations. If None, configurations are loaded from the model. - - log_stats: Whether to enable statistics logging - be written to files with stage-specific suffixes. - - stage_init_timeout: Per-stage init watchdog (seconds). Measured from - when the previous stage finished (possibly a prior Omni run with GPU - reuse/overlap) to when the current stage starts to initialize. - - shm_threshold_bytes: Threshold in bytes for using shared memory - for IPC. Objects larger than this threshold will use shared memory. - - worker_backend: Backend for worker processes. Default is "multi_process". - - ray_address: Address of Ray cluster for Ray backend, if using Ray backend. - - batch_timeout: Timeout in seconds for batching requests within a stage - - init_timeout: Timeout in seconds for waiting for all stages to initialize + configurations. 
If None, configurations are resolved from model + pipeline factory. + - log_stats: Whether to enable statistics logging. + - stage_init_timeout: Timeout for per-stage initialization. + - init_timeout: Total timeout for orchestrator startup. + - async_chunk: Whether to enable async chunk mode. + - output_modalities: Requested output modalities. - Additional keyword arguments passed to stage engines. Example: - >>> async_llm = AsyncOmni(model="Qwen/Qwen2.5-Omni-7B") - >>> async for output in async_llm.generate( + >>> async_omni = AsyncOmni(model="Qwen/Qwen2.5-Omni-7B") + >>> async for output in async_omni.generate( ... prompt="Hello", ... request_id="req-1", ... sampling_params_list=[SamplingParams(), SamplingParams()] @@ -110,212 +66,66 @@ class AsyncOmni(OmniBase): ... print(output) """ - def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: - # Pause/resume control attributes + def __init__(self, model: str, **kwargs: Any) -> None: + OmniBase.__init__(self, model=model, **kwargs) self._pause_cond: asyncio.Condition = asyncio.Condition() self._paused: bool = False - - # Sleep mode tracking self._is_sleeping: bool = False + self.final_output_task: asyncio.Task | None = None - # Request state tracking - self.request_states: dict[str, ClientRequestState] = {} - self.output_handler: asyncio.Task | None = None - - # RPC results storage: {stage_id: {rpc_id: result}} - # Used to avoid race condition between output_handler and collective_rpc - self._rpc_results: dict[int, dict[str, dict[str, Any]]] = {} - - # CFG companion → parent request ID mapping for output routing - self._companion_to_parent: dict[str, str] = {} - - super().__init__(model, **kwargs) - - # Register weak reference cleanup (called on garbage collection) - self._weak_finalizer = weakref.finalize( - self, - _weak_close_cleanup_async, - self.stage_list, - self._stage_in_queues, - self._stage_out_queues, - self._ray_pg, - self.output_handler, - self._zmq_ctx, - getattr(self, "_inline_engine", 
None), - ) + self.config_path = self.engine.config_path + self.stage_configs = self.engine.stage_configs + self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None) + self.input_processor = self.engine.input_processor - async def get_supported_tasks(self) -> set[str]: - """Return supported tasks based on stage output modalities and capabilities.""" - tasks: set[str] = set() - if "text" in self.output_modalities or any(stage.is_comprehension for stage in self.stage_list): - tasks.add("generate") - if "audio" in self.output_modalities: - tasks.add("speech") - return tasks - - def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[str, Any]: - """Create default diffusion stage configuration.""" - # TODO: here is different from the Omni class. We should merge the two in the future. - cache_backend = kwargs.get("cache_backend", "none") - cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None)) - - devices = "0" - if "parallel_config" in kwargs: - parallel_config = kwargs["parallel_config"] - num_devices = kwargs["parallel_config"].world_size - for i in range(1, num_devices): - devices += f",{i}" - else: - ulysses_degree = kwargs.get("ulysses_degree") or 1 - ring_degree = kwargs.get("ring_degree") or 1 - sequence_parallel_size = kwargs.get("sequence_parallel_size") - tensor_parallel_size = kwargs.get("tensor_parallel_size") or 1 - enable_expert_parallel = kwargs.get("enable_expert_parallel") or False - cfg_parallel_size = kwargs.get("cfg_parallel_size") or 1 - vae_patch_parallel_size = kwargs.get("vae_patch_parallel_size") or 1 - use_hsdp = kwargs.get("use_hsdp", False) - hsdp_shard_size = kwargs.get("hsdp_shard_size", -1) - hsdp_replicate_size = kwargs.get("hsdp_replicate_size", 1) - if sequence_parallel_size is None: - sequence_parallel_size = ulysses_degree * ring_degree - - # Calculate num_devices: consider standalone HSDP - other_parallel_size = sequence_parallel_size * 
tensor_parallel_size * cfg_parallel_size - if use_hsdp and other_parallel_size == 1 and hsdp_shard_size > 0: - # Standalone HSDP: num_devices is determined by HSDP dimensions - num_devices = hsdp_shard_size * hsdp_replicate_size - else: - num_devices = other_parallel_size - - for i in range(1, num_devices): - devices += f",{i}" - parallel_config = DiffusionParallelConfig( - pipeline_parallel_size=1, - data_parallel_size=1, - tensor_parallel_size=tensor_parallel_size, - enable_expert_parallel=enable_expert_parallel, - sequence_parallel_size=sequence_parallel_size, - ulysses_degree=ulysses_degree, - ring_degree=ring_degree, - cfg_parallel_size=cfg_parallel_size, - vae_patch_parallel_size=vae_patch_parallel_size, - use_hsdp=use_hsdp, - hsdp_shard_size=hsdp_shard_size, - hsdp_replicate_size=hsdp_replicate_size, - ) - default_stage_cfg = [ - { - "stage_id": 0, - "stage_type": "diffusion", - "runtime": { - "process": True, - "devices": devices, - "max_batch_size": 1, - }, - "engine_args": { - "parallel_config": parallel_config, - "model_class_name": kwargs.get("model_class_name", None), - "vae_use_slicing": kwargs.get("vae_use_slicing", False), - "vae_use_tiling": kwargs.get("vae_use_tiling", False), - "cache_backend": cache_backend, - "cache_config": cache_config, - "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), - "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), - "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), - "enforce_eager": kwargs.get("enforce_eager", False), - "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), - "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), - "quantization": kwargs.get("quantization", None), - "worker_extension_cls": kwargs.get("worker_extension_cls", None), - "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), - "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), - "num_weight_load_threads": 
kwargs.get("num_weight_load_threads", 4), - }, - "final_output": True, - "final_output_type": "image", - } - ] - default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" - return default_stage_cfg - - def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str, Any]) -> None: - # Store vllm_config received from worker process (may be None for diffusion stages) - vllm_config = result.get("vllm_config") - if vllm_config is not None: - stage.set_vllm_config(vllm_config) - tokenizer = result.get("tokenizer") - if tokenizer is not None: - stage.set_tokenizer(tokenizer) - is_tracing_enabled = result.get("is_tracing_enabled") - if is_tracing_enabled is not None: - stage.set_is_tracing_enabled(is_tracing_enabled) - super()._process_stage_ready(stage, stage_id, result) - - def _wait_for_stages_ready(self, timeout: int = 120) -> None: - """Wait for all stages to report readiness.""" - super()._wait_for_stages_ready(timeout) - for stage in self.stage_list: - if stage.vllm_config is not None and stage.tokenizer is not None: - try: - vllm_config = stage.vllm_config - # Initialize input_processor - # OMNI: OmniInputProcessor creates tokenizer internally from vllm_config - self.input_processor = OmniInputProcessor( - vllm_config=vllm_config, - ) - # Initialize model_config - self.model_config = vllm_config.model_config - # Initialize io_processor - io_processor_plugin = self.model_config.io_processor_plugin - self.io_processor = get_io_processor(vllm_config, io_processor_plugin) - - logger.info( - f"[{self._name}] Initialized input_processor, " - f"io_processor, and model_config from stage-{stage.stage_id}", - ) - break - except Exception as e: - logger.warning( - f"[{self._name}] Failed to initialize processors from stage-{stage.stage_id}: {e}", - ) - # If no LLM stage found, set processors to None - if not hasattr(self, "input_processor") or self.input_processor is None: - logger.warning( - f"[{self._name}] No LLM stage found, processors will not be 
available. " - "This may cause issues with OpenAIServingModels." - ) - self.input_processor = None + stage_index = self._get_comprehension_stage_index() + if stage_index is None: self.io_processor = None - self.model_config = None - - def _setup_rpc_result_checkers(self) -> None: - """Override base class to use async-friendly RPC result checkers. - - In the async path the output_handler task drains the output queues - and stashes collective_rpc results into ``_rpc_results``. The - checker therefore only needs to look there (no queue draining). - """ - for stage in self.stage_list: - sid = stage.stage_id + else: + vllm_config = self.engine.stage_vllm_configs[stage_index] + io_processor_plugin = vllm_config.model_config.io_processor_plugin + self.io_processor = get_io_processor(vllm_config, io_processor_plugin) + + def _get_comprehension_stage_index(self) -> int | None: + fallback_idx: int | None = None + for idx, stage_client in enumerate(self.engine.stage_clients): + stage_vllm_config = self.engine.stage_vllm_configs[idx] + if stage_vllm_config is None: + continue + if fallback_idx is None: + fallback_idx = idx + if stage_client.is_comprehension: + return idx + return fallback_idx - def make_rpc_checker(stage_id: int): - def rpc_checker(rpc_id: str) -> dict[str, Any] | None: - if stage_id in self._rpc_results and rpc_id in self._rpc_results[stage_id]: - return self._rpc_results[stage_id].pop(rpc_id) - return None + @property + def renderer(self): + """Return the renderer from the engine input processor when available.""" + if self.input_processor is None: + return None + return self.input_processor.renderer - return rpc_checker + @property + def vllm_config(self): + """Return the vLLM config for the comprehension stage when present.""" + stage_index = self._get_comprehension_stage_index() + if stage_index is None: + return None + return self.engine.stage_vllm_configs[stage_index] - stage._rpc_result_checker = make_rpc_checker(sid) + async def get_vllm_config(self) 
-> Any: + """Compatibility helper for call sites expecting async vllm config access.""" + return self.vllm_config - def shutdown(self): - """Shutdown, cleaning up the background proc and IPC. + @property + def model_config(self): + """Return the model config for the comprehension stage when present.""" + vllm_config = self.vllm_config + if vllm_config is None: + return None + return vllm_config.model_config - Alias for close() method. Cleans up all stage processes - and inter-process communication resources. - """ - if hasattr(self, "_weak_finalizer"): - self._weak_finalizer() + # ==================== Generate Method ==================== async def generate( self, @@ -327,11 +137,8 @@ async def generate( ) -> AsyncGenerator[OmniRequestOutput, None]: """Generate outputs for the given prompt asynchronously. - Coordinates multi-stage pipeline through YAML configuration. - Each stage will use AsyncOmniLLM or AsyncOmniDiffusion based on stage_type. - Processes the prompt through all stages in the pipeline and yields - outputs as they become available. Each stage uses its corresponding - sampling parameters from the sampling_params_list. + Coordinates multi-stage pipeline execution. Processes the prompt through + all stages in the pipeline and yields outputs as they become available. Args: prompt: Prompt to process. Can be a text string, token IDs, @@ -344,776 +151,442 @@ async def generate( Yields: OmniRequestOutput objects as they are produced by each stage. - Each output contains the stage_id, final_output_type, and - the request_output from that stage. Raises: ValueError: If sampling_params_list has incorrect length. """ - # Wait until generation is resumed if the engine is paused. 
+ # Wait until generation is resumed if the engine is paused async with self._pause_cond: await self._pause_cond.wait_for(lambda: not self._paused) - if self._inline_diffusion: - async for output in self._generate_inline(prompt, request_id, sampling_params_list, output_modalities): - yield output - return + logger.debug(f"[AsyncOmni] generate() called for request {request_id}") - logger.debug(f"[{self._name}] generate() called") try: - # Start output handler on the first call to generate() - self._run_output_handler() - - # TODO: lora_request, trace_headers, priority are not supported yet - if sampling_params_list is None: - sampling_params_list = self.default_sampling_params_list - - if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") - - # Orchestrator keeps stage objects for input derivation - num_stages = len(self.stage_list) - # Track per-request start time for end-to-end timing - _req_start_ts: dict[int, float] = {} - _wall_start_ts: float = time.time() - # _last_finish_ts: float = _wall_start_ts - - # Determine the final stage for E2E stats (highest stage_id with - # final_output=True; fallback to last stage) - final_stage_id_for_e2e = get_final_stage_id_for_e2e( - output_modalities, self.output_modalities, self.stage_list - ) + # Start final output dispatcher on the first call to generate() + self._final_output_handler() + + sampling_params_list = self.resolve_sampling_params_list(sampling_params_list) - # Metrics/aggregation helper - metrics = OrchestratorAggregator( - num_stages=num_stages, - log_stats=self.log_stats, - wall_start_ts=_wall_start_ts, - final_stage_id_for_e2e=final_stage_id_for_e2e, + # Track per-request metrics + wall_start_ts = time.time() + req_start_ts: dict[str, float] = {} + + # Determine the final stage for E2E stats + final_stage_id_for_e2e = self._compute_final_stage_id(output_modalities) + + metrics = OrchestratorMetrics( + 
self.num_stages, + self.log_stats, + wall_start_ts, + final_stage_id_for_e2e, ) req_state = ClientRequestState(request_id) req_state.metrics = metrics self.request_states[request_id] = req_state - # Ensure modalities is in the prompt dict for CFG expansion - # (offline path includes it; online serving passes it separately) - if isinstance(prompt, dict) and output_modalities and "modalities" not in prompt: - prompt["modalities"] = output_modalities - - # CFG companion tracking (prompt expansion + lifecycle management) - cfg = CfgCompanionTracker( - prompt_expand_func=getattr(self.stage_list[0], "prompt_expand_func", None), - stage0_sampling_params=sampling_params_list[0], - ) - expanded_companions = cfg.expand_prompts({request_id: prompt}) - - sp0: SamplingParams = sampling_params_list[0] # type: ignore[index] - task = { - "request_id": request_id, - "engine_inputs": prompt, - "sampling_params": sp0, - } - self.stage_list[0].submit(task) - - # Submit CFG companion requests to stage-0 - if cfg.is_active: - for companion_id, companion_prompt in expanded_companions: - self._companion_to_parent[companion_id] = request_id - companion_task = { - "request_id": companion_id, - "engine_inputs": companion_prompt, - "sampling_params": cfg.stage0_sampling_params, - } - self.stage_list[0].submit(companion_task) - - metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() - _req_start_ts[request_id] = time.time() - logger.info( - f"[{self._name}] Entering scheduling loop: stages={num_stages}, final_stage={final_stage_id_for_e2e}" + # Add request to stage 0 (Orchestrator handles all stage transitions) + await self.engine.add_request_async( + request_id=request_id, + prompt=prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id_for_e2e, ) - if self.async_chunk: - stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)} - req_state.stage_queues = stage_queues - async for output in self._process_async_results( - 
request_id, - prompt, - sampling_params_list, - req_state, - metrics, - final_stage_id_for_e2e, - ): - yield output - else: - async for output in self._process_sequential_results( - request_id, - req_state, - metrics, - final_stage_id_for_e2e, - sampling_params_list, - prompt, - cfg=cfg, - ): - yield output - - logger.debug(f"[{self._name}] Request {request_id} finalized at stage-{final_stage_id_for_e2e}") - try: - # Finalize E2E metrics if not already done - metrics.on_finalize_request( - final_stage_id_for_e2e, - request_id, - _req_start_ts.get(request_id, _wall_start_ts), - ) + submit_ts = time.time() + req_state.metrics.stage_first_ts[0] = submit_ts + req_start_ts[request_id] = submit_ts + + # Process results based on mode + # Both sequential and async_chunk modes read the same message stream + # from Orchestrator; stage-transfer behavior differs inside + # Orchestrator._route_output(). + async for output in self._process_orchestrator_results( + request_id, + metrics, + final_stage_id_for_e2e, + req_start_ts, + wall_start_ts, + ): + yield output + + logger.debug(f"[AsyncOmni] Request {request_id} completed") + + self._log_summary_and_cleanup(request_id) - logger.debug(f"[{self._name}] All requests completed") - # Summarize and print stats - metrics.build_and_log_summary() - except Exception as e: - logger.exception(f"[{self._name}] Request {request_id} Failed to finalized/build/log summary: {e}") - finally: - self.request_states.pop(request_id, None) - if cfg.is_active: - for cid in cfg.get_companion_request_ids(request_id).values(): - self._companion_to_parent.pop(cid, None) except (asyncio.CancelledError, GeneratorExit): await self.abort(request_id) - logger.info("[AsyncOrchestrator] Request %s aborted.", request_id) + logger.info(f"[AsyncOmni] Request {request_id} aborted.") raise - async def _generate_inline( + async def encode( self, - prompt: OmniPromptType, + prompt: Any, + pooling_params: PoolingParams, request_id: str, - sampling_params_list: 
Sequence[OmniSamplingParams] | None = None, - output_modalities: list[str] | None = None, - ) -> AsyncGenerator[OmniRequestOutput, None]: - """Generate using inline diffusion engine (no stage worker subprocess). - - Eliminates Hop3 IPC overhead by running OmniDiffusion directly in the - orchestrator process. The blocking generate() call is offloaded to a - thread executor so the asyncio event loop remains responsive. + lora_request: LoRARequest | None = None, + trace_headers: dict[str, str] | None = None, + priority: int = 0, + tokenization_kwargs: dict[str, Any] | None = None, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """EngineClient.encode() stub. + + Omni pipeline currently exposes only generate() API at orchestrator level. """ - _wall_start_ts = time.time() - - if sampling_params_list is None: - sampling_params_list = self.default_sampling_params_list - sp0 = sampling_params_list[0] - - stage = self.stage_list[0] - final_stage_id_for_e2e = 0 + raise NotImplementedError("AsyncOmni.encode is not implemented.") - metrics = OrchestratorAggregator( - num_stages=1, - log_stats=self.log_stats, - wall_start_ts=_wall_start_ts, - final_stage_id_for_e2e=final_stage_id_for_e2e, - ) - metrics.stage_first_ts[0] = time.time() - - logger.info( - "[%s] Inline diffusion generate for request %s", - self._name, - request_id, - ) + # ==================== Processing Methods ==================== - try: - loop = asyncio.get_running_loop() - results = await loop.run_in_executor( - None, - self._inline_engine.generate, - prompt, - sp0, - [request_id], - ) + async def _process_orchestrator_results( + self, + request_id: str, + metrics: OrchestratorMetrics, + final_stage_id_for_e2e: int, + req_start_ts: dict[str, float], + wall_start_ts: float, + ) -> AsyncGenerator[OmniRequestOutput, None]: + """Read results from the Orchestrator (via the request's asyncio.Queue) + and yield OmniRequestOutput objects. 
- for result in results: - images = getattr(result, "images", None) or [] - finished = getattr(result, "finished", True) + The Orchestrator handles all stage-to-stage transfers. This method + only processes final outputs that arrive on the per-request queue. + """ + req_state = self.request_states.get(request_id) + if req_state is None: + return - output_to_yield = OmniRequestOutput( - stage_id=0, - final_output_type=stage.final_output_type, - request_output=result, - images=images, - finished=finished, - ) + while True: + result = await req_state.queue.get() - metrics.stage_last_ts[0] = time.time() - yield output_to_yield + stage_id = result.get("stage_id", 0) - try: - metrics.on_finalize_request( - final_stage_id_for_e2e, + # Check for errors + if "error" in result: + logger.error( + "[AsyncOmni] Orchestrator error for req=%s stage-%s: %s", request_id, - _wall_start_ts, - ) - metrics.build_and_log_summary() - except Exception as e: - logger.exception( - "[%s] Failed to finalize inline metrics: %s", - self._name, - e, + stage_id, + result["error"], ) - - except (asyncio.CancelledError, GeneratorExit): - logger.info( - "[%s] Inline request %s cancelled.", - self._name, - request_id, - ) - raise - except Exception as e: - logger.exception( - "[%s] Inline diffusion failed for request %s: %s", - self._name, - request_id, - e, + raise RuntimeError(result) + + # Process the result (constructs OmniRequestOutput) + output_to_yield = self._process_single_result( + result, + stage_id, + metrics, + req_start_ts, + wall_start_ts, + final_stage_id_for_e2e, ) - raise - async def _process_async_results( - self, - request_id: str, - prompt: Any, - sampling_params_list: list[SamplingParams], - req_state: ClientRequestState, - metrics: OrchestratorAggregator, - final_stage_id_for_e2e: int, - ) -> AsyncGenerator[OmniRequestOutput, None]: - all_stages_finished = {stage_id: False for stage_id in range(final_stage_id_for_e2e + 1)} - submit_flag = True - _loop_iter = 0 - 
_last_progress_ts = time.time() - while not all(all_stages_finished.values()): - _loop_iter += 1 - for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): - if all_stages_finished[stage_id]: - continue - try: - result = req_state.stage_queues[stage_id].get_nowait() - except asyncio.QueueEmpty: - await asyncio.sleep(0.001) - continue - _last_progress_ts = time.time() - engine_outputs, finished, output_to_yield = self._process_single_result( - result, - stage, - stage_id, - metrics, - ) - if submit_flag and stage_id == 0: - submit_flag = False - prompt_token_ids = getattr(engine_outputs, "prompt_token_ids", None) - if prompt_token_ids is None: - prompt_token_ids = [] - engine_input = copy.deepcopy(prompt) - try: - next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids)) - except Exception: - raise - engine_input["prompt_token_ids"] = [0] * next_prompt_len - engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None - for _mm_key in ("mm_kwargs", "mm_hashes", "mm_placeholders", "multi_modal_uuids"): - engine_input.pop(_mm_key, None) - if engine_input.get("type") == "multimodal": - engine_input["type"] = "token" - for i in range(1, len(self.stage_list)): - task = { - "request_id": request_id, - "engine_inputs": engine_input, - "sampling_params": sampling_params_list[i], - } - self.stage_list[i].submit(task) - metrics.stage_first_ts[i] = time.time() - all_stages_finished[stage_id] = finished - - if output_to_yield: - yield output_to_yield - - async def _process_sequential_results( - self, - request_id: str, - req_state: ClientRequestState, - metrics: OrchestratorAggregator, - final_stage_id_for_e2e: int, - sampling_params_list: list[SamplingParams], - prompt: Any, - cfg: CfgCompanionTracker | None = None, - ) -> AsyncGenerator[OmniRequestOutput, None]: - for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): - cfg_stage0 = stage_id == 0 and cfg is not None and cfg.is_active - 
finished = False - - while True: - if finished and ( - not cfg_stage0 or cfg.all_companions_done(request_id) or cfg.is_parent_failed(request_id) - ): - break - - result = await req_state.queue.get() - - if cfg is not None and cfg.is_companion(result.get("request_id", "")): - if cfg_stage0: - rid = result.get("request_id") - if "error" in result: - cfg.on_companion_error(rid) - else: - cfg.on_companion_completed(rid) - continue - - engine_outputs, finished, output_to_yield = self._process_single_result( - result, - stage, + if output_to_yield: + logger.debug( + "[AsyncOmni] req=%s stage-%s yielding final_output_type=%s", + request_id, stage_id, - metrics, - ) - if output_to_yield: - yield output_to_yield - if not isinstance(engine_outputs, list): - engine_outputs = [engine_outputs] - stage.set_engine_outputs(engine_outputs) - # Forward to next stage if there is one - next_stage_id = stage_id + 1 - if next_stage_id <= final_stage_id_for_e2e: - next_stage: OmniStage = self.stage_list[next_stage_id] - # Derive inputs for the next stage, record postprocess time - with metrics.stage_postprocess_timer(stage_id, request_id): - next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) - sp_next: SamplingParams = sampling_params_list[next_stage_id] - - if cfg is not None and cfg.is_active and not cfg.is_parent_failed(request_id): - if isinstance(sp_next, OmniDiffusionSamplingParams): - sp_next = copy.deepcopy(sp_next) - sp_next.cfg_kv_request_ids = cfg.get_companion_request_ids(request_id) - logger.info( - "Attaching cfg_kv_request_ids=%s to request %s", - sp_next.cfg_kv_request_ids, - request_id, - ) - - # Check if we have a connector for this edge - connector_key = (str(stage_id), str(next_stage_id)) - connector = self.connectors.get(connector_key) - - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=request_id, - 
next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=prompt, - next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - # Fallback logic removed as we now enforce connector usage. - # If no connector is found or send fails, we log an error and raise, - # because continuing would cause the request to be silently dropped - # and the orchestrator to hang waiting for completion. - error_msg = ( - f"[{self._name}] Failed to send request {request_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." - ) - logger.error(error_msg) - raise RuntimeError(error_msg) - logger.debug(f"[{self._name}] Forwarded request {request_id} to stage-{next_stage_id}") - else: - logger.debug(f"[{self._name}] Request {request_id} fully completed") - - def _process_single_result( - self, - result: dict[str, Any], - stage: OmniStage, - stage_id: int, - metrics: OrchestratorAggregator, - ) -> tuple[Any, bool, OmniRequestOutput | None]: - """ - Process a single result dictionary from a stage. - Returns: - engine_outputs: The decoded outputs. - finished: Whether the stage processing is finished for this request. - output_to_yield: An OmniRequestOutput to yield, or None. 
- """ - req_id = result.get("request_id") - if "error" in result: - logger.error( - f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", - ) - raise RuntimeError(result) - - engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") - - if isinstance(engine_outputs, list): - engine_outputs = engine_outputs[0] - - finished = engine_outputs.finished - - output_to_yield = None - - if getattr(stage, "final_output", False): - # Construct output to yield - images = [] - if stage.final_output_type == "image": - if isinstance(engine_outputs, OmniRequestOutput) and engine_outputs.images: - images = engine_outputs.images - elif hasattr(engine_outputs, "images") and engine_outputs.images: - images = engine_outputs.images - - if stage.final_output_type == "image": - # Propagate custom_output from inner diffusion output - custom_output = {} - if isinstance(engine_outputs, OmniRequestOutput): - custom_output = engine_outputs.custom_output or {} - output_to_yield = OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, - request_output=engine_outputs, - images=images, - finished=finished, - _custom_output=custom_output, + getattr(output_to_yield, "final_output_type", None), ) - else: - output_to_yield = OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, - request_output=engine_outputs, - finished=finished, - ) - # Mark last output time - metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) - - metrics.process_stage_metrics( - result=result, - stage_type=stage.stage_type, - stage_id=stage_id, - req_id=req_id, - engine_outputs=engine_outputs, - finished=finished, - final_output_type=stage.final_output_type, - output_to_yield=output_to_yield, - ) + yield output_to_yield - logger.debug( - f"[{self._name}] Stage-{stage_id} completed request {req_id}; forwarding or finalizing", - ) + # The Orchestrator sets "finished" when the 
final stage is done + if result.get("finished"): + break - return engine_outputs, finished, output_to_yield + # ==================== Output Handler ==================== - def _run_output_handler(self) -> None: - if self.output_handler is not None: + def _final_output_handler(self) -> None: + """Start the final output handler if not already running. + + This handler reads messages from the Orchestrator output queue and + routes them to per-request asyncio.Queues. + """ + if self.final_output_task is not None: return - stage_list = self.stage_list - request_states = self.request_states - companion_to_parent = self._companion_to_parent + engine = self.engine - async def output_handler(): + async def _final_output_loop(): + """Background coroutine that dispatches final outputs to request queues.""" try: while True: - idle = True - for stage_id, stage in enumerate(stage_list): - result = stage.try_collect() - if result is None: - continue - idle = False - if result.get("type") == "stage_ready": - # Only happens when stage is initialized slower than expected, - # so we wait for a short time and try again - await asyncio.sleep(0.05) - continue - # Handle collective_rpc results separately to avoid - # race condition with the polling in collective_rpc - if result.get("type") == "collective_rpc_result": - rpc_id = result.get("rpc_id") - if rpc_id: - if stage_id not in self._rpc_results: - self._rpc_results[stage_id] = {} - self._rpc_results[stage_id][rpc_id] = result - continue - req_id = result.get("request_id") - req_state = request_states.get(req_id) - if req_state is None: - parent_id = companion_to_parent.get(req_id) - if parent_id is not None: - req_state = request_states.get(parent_id) - if req_state is None: - logger.debug( - f"[{self._name}] Request may have been aborted; \ - dropping output for req {req_id} at stage-{stage_id}" - ) - continue - if hasattr(req_state, "stage_queues") and stage_id in req_state.stage_queues: - await 
req_state.stage_queues[stage_id].put(result) - else: - # Fallback to old behavior for compatibility - await req_state.queue.put(result) - req_state.stage_id = stage_id - if idle: - await asyncio.sleep(0.001) # Avoid CPU overload when idle - else: - await asyncio.sleep(0) + msg = await engine.try_get_output_async() + if msg is None: + await asyncio.sleep(_FINAL_OUTPUT_IDLE_SLEEP_S) + continue + + should_continue, _, stage_id, req_state = self._handle_output_message(msg) + if should_continue: + continue + + req_state.stage_id = stage_id + + # Route to the per-request queue + await req_state.queue.put(msg) + + except asyncio.CancelledError: + raise except Exception as e: - logger.exception("AsyncOmni output_handler failed.") - for req_state in request_states.values(): - error_msg = {"request_id": req_state.request_id, "error": str(e)} - # Send error to all stage queues - if hasattr(req_state, "stage_queues"): - for queue in req_state.stage_queues.values(): - await queue.put(error_msg) - else: - await req_state.queue.put(error_msg) + logger.exception("[AsyncOmni] final_output_loop failed.") + for req_state in list(self.request_states.values()): error_msg = {"request_id": req_state.request_id, "error": str(e)} - self.output_handler = None # Make possible for restart + await req_state.queue.put(error_msg) + self.final_output_task = None - self.output_handler = asyncio.create_task(output_handler()) + self.final_output_task = asyncio.create_task(_final_output_loop()) + logger.debug("[AsyncOmni] Final output handler started") - @property - def is_running(self) -> bool: - if self._inline_diffusion: - return self._inline_engine is not None - return len(self._stage_in_queues) > 0 + # ==================== Control Methods ==================== - @property - def is_stopped(self) -> bool: - return self.errored + async def collective_rpc( + self, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] 
= (), + kwargs: dict[str, Any] | None = None, + stage_ids: list[int] | None = None, + ) -> list[Any]: + """Execute a best-effort control RPC on selected stages. - @property - def errored(self) -> bool: - return not self.is_running + Unsupported stages currently return a TODO-style result dict instead of + failing the entire call. This keeps AsyncOmni usable while the orchestrator + control plane is still being filled out. + """ + results = await self.engine.collective_rpc_async( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, + stage_ids=stage_ids, + ) - @property - def _name(self) -> str: - return "AsyncOrchestrator" + unsupported_stage_ids: list[int] = [] + effective_stage_ids = stage_ids or list(range(len(results))) + for index, result in enumerate(results): + if isinstance(result, dict) and result.get("todo"): + unsupported_stage_ids.append(effective_stage_ids[index]) - @property - def is_async(self) -> bool: - return True + if unsupported_stage_ids: + logger.warning( + "[AsyncOmni] collective_rpc(%s) has TODO support on stage(s): %s", + method, + unsupported_stage_ids, + ) - @property - def dead_error(self) -> BaseException: - return EngineDeadError() + return results + + @staticmethod + def _coerce_stage_bool(result: Any) -> bool: + """Reduce a stage RPC result to a boolean. + + Some stage RPCs may return worker-level lists like ``[True]``; + diffusion wrappers usually return a plain bool. 
+ """ + if isinstance(result, list): + return all(bool(item) for item in result) + return bool(result) async def abort(self, request_id: str | Iterable[str]) -> None: - if self._inline_diffusion: - if self._inline_engine is not None: - self._inline_engine.engine.abort(request_id) - return None - abort_task = {"type": OmniStageTaskType.ABORT, "request_id": request_id} - for stage in self.stage_list: - stage.submit(abort_task) - return None - - async def get_vllm_config(self) -> VllmConfig: - for stage in self.stage_list: - if stage.is_comprehension: - # Use the vllm_config received from worker process - if stage.vllm_config is not None: - return stage.vllm_config - return None - - async def get_model_config(self) -> OmniModelConfig: - for stage in self.stage_list: - if stage.is_comprehension: - # Use the vllm_config received from worker process - if stage.vllm_config is not None: - return stage.vllm_config.model_config - return None + """Abort request(s) via the Orchestrator.""" + request_ids = [request_id] if isinstance(request_id, str) else list(request_id) + await self.engine.abort_async(request_ids) + for req_id in request_ids: + self.request_states.pop(req_id, None) + if self.log_stats: + logger.info("[AsyncOmni] Aborted request(s) %s", ",".join(request_ids)) - async def get_input_preprocessor(self) -> InputPreprocessor: - return None + async def pause_generation( + self, + *, + wait_for_inflight_requests: bool = False, + clear_cache: bool = True, + ) -> None: + """Pause generation.""" + async with self._pause_cond: + if self._paused: + return + self._paused = True - async def get_tokenizer(self) -> TokenizerLike: - for stage in self.stage_list: - if stage.is_comprehension: - return stage.tokenizer - return None + # TODO: Implement request draining if wait_for_inflight_requests - async def is_tracing_enabled(self) -> bool: - for stage in self.stage_list: - if stage.is_comprehension: - return stage.is_tracing_enabled - return False + if clear_cache: + # Clear 
caches for all stages. + await self.reset_prefix_cache( + reset_running_requests=not wait_for_inflight_requests, + reset_connector=True, + ) + await self.reset_mm_cache() + await self.reset_encoder_cache() - @property - def renderer(self): - """Return the renderer from input_processor if available. + async def resume_generation(self) -> None: + """Resume generation.""" + async with self._pause_cond: + self._paused = False + self._pause_cond.notify_all() - OMNI: Required by upstream OpenAIServingModels.__init__ which - accesses engine_client.renderer. + async def is_paused(self) -> bool: + """Check if paused.""" + async with self._pause_cond: + return self._paused + + async def start_profile(self, stages: list[int] | None = None) -> list[Any]: + """Start profiling specified stages. + + TODO(AsyncOmni): normalize return payloads across LLM/diffusion stages. """ - if self.input_processor is None: - return None - return self.input_processor.renderer + return await self.collective_rpc(method="start_profile", stage_ids=stages) - async def do_log_stats(self) -> None: - pass + async def stop_profile(self, stages: list[int] | None = None) -> list[Any]: + """Stop profiling specified stages. - async def check_health(self) -> None: - pass + TODO(AsyncOmni): normalize return payloads across LLM/diffusion stages. + """ + return await self.collective_rpc(method="stop_profile", stage_ids=stages) async def reset_mm_cache(self) -> None: - pass - - async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - pass + """Reset the multi-modal cache for all stages. - async def collective_rpc( - self, - method: str | Callable[..., _R], - timeout: float | None = None, - args: tuple = (), - kwargs: dict[str, Any] | None = None, - ) -> list[_R]: - """Execute an RPC call on all stages asynchronously. + TODO: Forward to Orchestrator process via message. 
+ """ + logger.warning("[AsyncOmni] reset_mm_cache not yet supported with Orchestrator process") - Args: - method: Name of the method to execute or callable - timeout: Optional timeout in seconds - args: Positional arguments for the method - kwargs: Keyword arguments for the method + async def reset_encoder_cache(self) -> None: + """Reset the encoder cache for all stages. - Returns: - List of results from each stage + TODO: Forward to Orchestrator process via message. """ - self._run_output_handler() - # Run synchronous collective_rpc in thread pool to avoid blocking event loop - loop = asyncio.get_event_loop() - - async def run_stage_rpc(stage: OmniStage) -> _R: - return await loop.run_in_executor( - None, - stage.collective_rpc, - method, - timeout, - args, - kwargs, - ) + logger.warning("[AsyncOmni] reset_encoder_cache not yet supported with Orchestrator process") - # Run all stages concurrently - results = await asyncio.gather(*[run_stage_rpc(stage) for stage in self.stage_list]) - return list(results) + async def reset_prefix_cache( + self, + reset_running_requests: bool = False, + reset_connector: bool = False, + ) -> bool: + """Reset the prefix cache for all stages. + + TODO: Forward to Orchestrator process via message. + """ + logger.warning("[AsyncOmni] reset_prefix_cache not yet supported with Orchestrator process") + return True async def sleep(self, level: int = 1) -> None: + """Sleep all stages. + + Best-effort: unsupported stages will emit a TODO result. + """ self._is_sleeping = True await self.collective_rpc(method="sleep", args=(level,)) async def wake_up(self, tags: list[str] | None = None) -> None: + """Wake up all stages. + + Best-effort: unsupported stages will emit a TODO result. + """ self._is_sleeping = False await self.collective_rpc(method="wake_up", args=(tags,)) async def is_sleeping(self) -> bool: - """Check whether the engine is sleeping""" - return getattr(self, "_is_sleeping", False) + """Return whether all stages are sleeping. 
+ + TODO(AsyncOmni): query the orchestrator once all stage backends expose + a real sleeping-state RPC. For now we track the requested state locally. + """ + return self._is_sleeping async def add_lora(self, lora_request: LoRARequest) -> bool: - """Load a new LoRA adapter into the engine for future requests.""" - result = await self.collective_rpc(method="add_lora", args=(lora_request,)) - return result[0][0] + """Load a new LoRA adapter into all stages. + + Returns True only if all concretely-implemented stages report success. + """ + results = await self.collective_rpc(method="add_lora", args=(lora_request,)) + concrete_results = [r for r in results if not (isinstance(r, dict) and r.get("todo"))] + return all(self._coerce_stage_bool(r) for r in concrete_results) if concrete_results else False async def remove_lora(self, adapter_id: int) -> bool: - """Remove a LoRA adapter from the engine.""" - result = await self.collective_rpc(method="remove_lora", args=(adapter_id,)) - return result[0][0] + """Remove a LoRA adapter from all stages. + + TODO(AsyncOmni): add richer per-stage error reporting to the public API. 
+ """ + results = await self.collective_rpc(method="remove_lora", args=(adapter_id,)) + concrete_results = [r for r in results if not (isinstance(r, dict) and r.get("todo"))] + return all(self._coerce_stage_bool(r) for r in concrete_results) if concrete_results else False async def list_loras(self) -> list[int]: - """List all loaded LoRA adapter IDs.""" - result = await self.collective_rpc(method="list_loras") - return result[0][0] + """List all loaded LoRA adapter IDs across stages.""" + results = await self.collective_rpc(method="list_loras") + merged: set[int] = set() + for result in results: + if isinstance(result, dict) and result.get("todo"): + continue + if isinstance(result, list): + merged.update(result) + return sorted(merged) async def pin_lora(self, adapter_id: int) -> bool: - """Pin a LoRA adapter so it is not evicted from the cache.""" - result = await self.collective_rpc(method="pin_lora", args=(adapter_id,)) - return result[0][0] + """Pin a LoRA adapter across stages.""" + results = await self.collective_rpc(method="pin_lora", args=(adapter_id,)) + concrete_results = [r for r in results if not (isinstance(r, dict) and r.get("todo"))] + return all(self._coerce_stage_bool(r) for r in concrete_results) if concrete_results else False - async def encode( - self, - *args, - **kwargs, - ): - """Generate outputs for a request from a pooling model.""" - raise NotImplementedError("encode() is not implemented for AsyncOmni") + # ==================== Properties ==================== - async def start_profile(self, stages: list[int] | None = None) -> None: - """Start profiling for specified stages. - - Async wrapper around the base implementation for API consistency. - - Args: - stages: List of stage IDs to start profiling. If None, starts - profiling for all stages that have profiling enabled. - - Example: - >>> await async_omni.start_profile() - >>> async for output in async_omni.generate(...): - ... 
pass - >>> await async_omni.stop_profile() - """ - super().start_profile(stages) + @property + def is_running(self) -> bool: + """Check if the engine is running.""" + return self.final_output_task is not None and not self.final_output_task.done() - async def stop_profile(self, stages: list[int] | None = None) -> None: - """Stop profiling for specified stages. + @property + def errored(self) -> bool: + """Whether orchestrator thread has stopped unexpectedly.""" + return not self.engine.is_alive() - Async wrapper around the base implementation for API consistency. + @property + def is_stopped(self) -> bool: + """EngineClient abstract property implementation.""" + return self.errored - Args: - stages: List of stage IDs to stop profiling. If None, stops - profiling for all stages. - - Example: - >>> await async_omni.start_profile() - >>> async for output in async_omni.generate(...): - ... pass - >>> await async_omni.stop_profile() - """ - super().stop_profile(stages) + @property + def dead_error(self) -> BaseException: + """EngineClient abstract property implementation.""" + return EngineDeadError() - async def pause_generation( - self, - *, - wait_for_inflight_requests: bool = False, - clear_cache: bool = True, - ) -> None: - """ - Pause generation to allow model weight updates. + # ==================== EngineClient Interface ==================== - New generation/encoding requests are blocked until resume. + async def get_input_preprocessor(self) -> InputPreprocessor: + """Get input preprocessor.""" + return self.input_processor - Args: - wait_for_inflight_requests: When ``True`` waits for in-flight - requests to finish before pausing. When ``False`` (default), - immediately aborts any in-flight requests. - clear_cache: Whether to clear KV cache and prefix cache after - draining. Set to ``False`` to preserve cache for faster resume. - Default is ``True`` (clear caches). 
- """ + async def get_tokenizer(self) -> TokenizerLike: + """Get tokenizer for the comprehension stage.""" + stage_index = self._get_comprehension_stage_index() + if stage_index is not None: + tokenizer = self.engine.output_processors[stage_index].tokenizer + if tokenizer is not None: + return tokenizer + return self.input_processor.tokenizer # type: ignore[return-value] - async with self._pause_cond: - if self._paused: - return - self._paused = True + async def is_tracing_enabled(self) -> bool: + """Check if tracing is enabled.""" + return False - # Note: AsyncOmni uses a stage-based architecture without a central - # output_processor. For now, we simply set the pause flag and let - # new requests wait. In-flight requests will complete naturally. - # TODO: Implement request abortion for stages if needed. + async def do_log_stats(self) -> None: + """Log statistics. - # Clear cache if requested - if clear_cache: - await self.reset_prefix_cache() - await self.reset_mm_cache() + TODO: Forward to Orchestrator process via message. 
+ """ + pass - async def resume_generation(self) -> None: - """Resume generation after :meth:`pause_generation`.""" + async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + """Return the task set exposed by the orchestrator-backed engine.""" + return tuple(self.engine.supported_tasks) - async with self._pause_cond: - self._paused = False - self._pause_cond.notify_all() # Wake up all waiting requests + async def check_health(self) -> None: + """Check engine health by verifying the Orchestrator process is alive.""" + OmniBase.check_health(self) - async def is_paused(self) -> bool: - """Return whether the engine is currently paused.""" + # ==================== Shutdown ==================== - async with self._pause_cond: - return self._paused + def shutdown(self) -> None: + """Shutdown the engine.""" + if self.final_output_task is not None: + self.final_output_task.cancel() + self.final_output_task = None + OmniBase.shutdown(self) diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 37d358d451..08812223db 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -10,6 +10,7 @@ import asyncio import uuid +import weakref from collections.abc import AsyncGenerator, Iterable from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -27,6 +28,18 @@ logger = init_logger(__name__) +def _weak_close_async_omni_diffusion(engine: DiffusionEngine, executor: ThreadPoolExecutor) -> None: + """Best-effort diffusion cleanup for GC finalization.""" + try: + engine.close() + except Exception: + pass + try: + executor.shutdown(wait=False) + except Exception: + pass + + class AsyncOmniDiffusion: """Async entry point for vLLM-Omni diffusion model inference. 
@@ -127,6 +140,12 @@ def __init__( # Thread pool for running sync engine in async context self._executor = ThreadPoolExecutor(max_workers=1) self._closed = False + self._weak_finalizer = weakref.finalize( + self, + _weak_close_async_omni_diffusion, + self.engine, + self._executor, + ) logger.info("AsyncOmniDiffusion initialized with model: %s", model) @@ -217,6 +236,9 @@ def close(self) -> None: if self._closed: return self._closed = True + finalizer = getattr(self, "_weak_finalizer", None) + if finalizer is not None and finalizer.alive: + finalizer.detach() try: self.engine.close() @@ -234,13 +256,6 @@ def shutdown(self) -> None: """Alias for close() method.""" self.close() - def __del__(self) -> None: - """Best-effort cleanup on deletion.""" - try: - self.close() - except Exception: - pass - async def abort(self, request_id: str | Iterable[str]) -> None: """Abort a request.""" self.engine.abort(request_id) diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py deleted file mode 100644 index d63e4f154b..0000000000 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ /dev/null @@ -1,216 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -import os -import socket -from typing import TYPE_CHECKING - -import torch -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.renderers import renderer_from_config -from vllm.tracing import init_tracer -from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.usage.usage_lib import UsageContext -from vllm.utils.func_utils import deprecate_kwargs -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager - -from 
vllm_omni.engine.arg_utils import AsyncOmniEngineArgs -from vllm_omni.engine.input_processor import OmniInputProcessor -from vllm_omni.engine.output_processor import MultimodalOutputProcessor - -if TYPE_CHECKING: - pass - -logger = init_logger(__name__) - - -class AsyncOmniLLM(AsyncLLM): - """Async single-stage LLM engine for use within a stage worker process. - - This class extends the base vLLM AsyncLLM class with omni-specific - processors for handling multimodal inputs and outputs. It is used - internally by AsyncOmniStage workers and should not be instantiated - directly by users. - - Args: - engine_args: AsyncOmniEngineArgs containing engine configuration - vllm_config: Global vLLM configuration - executor_class: Executor implementation class, e.g. MultiprocExecutor - log_stats: Whether to log statistics - usage_context: Usage context of the LLM (default: ENGINE_CONTEXT) - mm_registry: Multi-modal registry for processing multimodal inputs - use_cached_outputs: Whether to use cached outputs - log_requests: Whether to log requests - start_engine_loop: Whether to start the engine loop automatically - stat_loggers: Customized stat loggers for the engine. - If not provided, default stat loggers will be used. - Note: Stat logger interface may change in V1. 
- client_addresses: Optional dictionary mapping client names to addresses - client_count: Total number of clients (default: 1) - client_index: Index of this client (default: 0) - """ - - def __init__( - self, - engine_args: AsyncOmniEngineArgs, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - use_cached_outputs: bool = False, - log_requests: bool = True, - start_engine_loop: bool = True, - stat_loggers: list[StatLoggerFactory] | None = None, - client_addresses: dict[str, str] | None = None, - client_count: int = 1, - client_index: int = 0, - ) -> None: - """ - Create an AsyncOmniLLM. - - Args: - vllm_config: global configuration. - executor_class: an Executor impl, e.g. MultiprocExecutor. - log_stats: Whether to log stats. - usage_context: Usage context of the LLM. - mm_registry: Multi-modal registry. - use_cached_outputs: Whether to use cached outputs. - log_requests: Whether to log requests. - start_engine_loop: Whether to start the engine loop. - stat_loggers: customized stat loggers for the engine. - If not provided, default stat loggers will be used. - PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE - IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. - - Returns: - None - """ - # Ensure we can serialize custom transformer configs - maybe_register_config_serialize_by_value() - - self.model_config = vllm_config.model_config - self.vllm_config = vllm_config - self.observability_config = vllm_config.observability_config - self.log_requests = log_requests - - self.log_stats = log_stats or (stat_loggers is not None) - if not log_stats and stat_loggers is not None: - logger.info( - "AsyncLLM created with log_stats=False and non-empty custom logger list; " - "enabling logging without default stat loggers" - ) - - # InputProcessor (converts Inputs --> EngineCoreRequests). 
- self.input_processor = OmniInputProcessor( - vllm_config=vllm_config, - mm_registry=mm_registry, - ) - - self.renderer = renderer_from_config(self.vllm_config) - - # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = MultimodalOutputProcessor( - tokenizer=self.renderer.tokenizer, - log_stats=self.log_stats, - engine_core_output_type=engine_args.engine_output_type, - ) - - if self.observability_config.otlp_traces_endpoint is not None: - tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) - self.output_processor.tracer = tracer - - # Pause / resume state for async RL workflows. - self._pause_cond = asyncio.Condition() - self._paused = False - - # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_async_mp_client( - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=self.log_stats, - client_addresses=client_addresses, - client_count=client_count, - client_index=client_index, - ) - - # Loggers. - self.logger_manager: StatLoggerManager | None = None - if self.log_stats: - self.logger_manager = StatLoggerManager( - vllm_config=vllm_config, - engine_idxs=self.engine_core.engine_ranks_managed, - custom_stat_loggers=stat_loggers, - enable_default_loggers=log_stats, - client_count=client_count, - ) - self.logger_manager.log_engine_initialized() - - self.output_handler: asyncio.Task | None = None - try: - # Start output handler eagerly if we are in the asyncio eventloop. - asyncio.get_running_loop() - self._run_output_handler() - except RuntimeError: - pass - - # Use profiler_config from vllm_config (new way, aligned with vllm v1) - if vllm_config.profiler_config.profiler == "torch" and not vllm_config.profiler_config.ignore_frontend: - profiler_dir = vllm_config.profiler_config.torch_profiler_dir - logger.info( - "Torch profiler enabled. 
AsyncOmniLLM CPU traces will be collected under %s", - profiler_dir, - ) - worker_name = f"{socket.gethostname()}_{os.getpid()}.async_omni_llm" - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - ], - with_stack=vllm_config.profiler_config.torch_profiler_with_stack, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - profiler_dir, - worker_name=worker_name, - use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip, - ), - ) - else: - self.profiler = None - - @classmethod - @deprecate_kwargs( - "disable_log_requests", - additional_message=("This argument will have no effect. Use `enable_log_requests` instead."), - ) - def from_vllm_config( - cls, - vllm_config: VllmConfig, - engine_args: AsyncOmniEngineArgs, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: list[StatLoggerFactory] | None = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - client_addresses: dict[str, str] | None = None, - client_count: int = 1, - client_index: int = 0, - disable_log_requests: bool = True, # Deprecated, will be removed - ) -> "AsyncLLM": - # Create the LLMEngine. 
- return cls( - vllm_config=vllm_config, - executor_class=Executor.get_class(vllm_config), - start_engine_loop=start_engine_loop, - stat_loggers=stat_loggers, - log_requests=enable_log_requests, - log_stats=not disable_log_stats, - usage_context=usage_context, - client_addresses=client_addresses, - client_count=client_count, - client_index=client_index, - engine_args=engine_args, - ) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 7365e73bb1..a70ac64b6a 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -8,37 +8,20 @@ import argparse import json import os -import signal from typing import Any -import msgspec.msgpack import uvloop -import zmq from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.logger import init_logger from vllm.utils.argparse_utils import FlexibleArgumentParser -from vllm.utils.network_utils import make_zmq_socket -from vllm.v1.utils import get_engine_client_zmq_addr - -from vllm_omni.distributed.omni_connectors import ( - get_connectors_config_for_stage, - load_omni_transfer_config, -) -from vllm_omni.distributed.omni_connectors.utils.initialization import ( - resolve_omni_kv_config_for_stage, -) + from vllm_omni.entrypoints.cli.logo import log_logo -from vllm_omni.entrypoints.omni import OmniBase, omni_snapshot_download -from vllm_omni.entrypoints.omni_stage import OmniStage from vllm_omni.entrypoints.openai.api_server import omni_run_server -from vllm_omni.entrypoints.utils import inject_omni_kv_config logger = init_logger(__name__) -HANDSHAKE_TIMEOUT_MINS = 5 - DESCRIPTION = """Launch a local OpenAI-compatible API server to serve Omni models via HTTP. Supports both multi-stage LLM models and diffusion models. 
@@ -372,123 +355,33 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[str, Any]]: - omni_base = OmniBase.__new__(OmniBase) - return omni_base._create_default_diffusion_stage_cfg(vars(args)) + """Create default diffusion stage configuration. + + Uses AsyncOmniEngine's implementation which doesn't have OmegaConf + compatibility issues. + """ + from vllm_omni.engine.async_omni_engine import AsyncOmniEngine + + return AsyncOmniEngine._create_default_diffusion_stage_cfg(vars(args)) def run_headless(args: argparse.Namespace) -> None: - if args.api_server_count is not None and args.api_server_count > 1: - raise ValueError("api_server_count can't be set in headless mode") - if args.worker_backend != "multi_process": - raise ValueError("headless mode requires worker_backend=multi_process") - - model = omni_snapshot_download(args.model) - - omni_base = OmniBase.__new__(OmniBase) - args_dict = vars(args).copy() - args_dict["model"] = model - config_path, stage_configs = omni_base._resolve_stage_configs(model, args_dict) - - single_stage_id = args.stage_id - if single_stage_id is None: - if len(stage_configs) != 1: - raise ValueError("--stage-id is required in headless mode for multi-stage configs") - single_stage_id = getattr(stage_configs[0], "stage_id", 0) - - stage_config = None - for cfg in stage_configs: - if getattr(cfg, "stage_id", None) == single_stage_id: - stage_config = cfg - break - if stage_config is None: - raise ValueError(f"No stage matches stage_id={single_stage_id}.") - - # TODO(wuhang): Support connectors config by cli - transfer_config = load_omni_transfer_config(config_path, default_shm_threshold=args.shm_threshold_bytes) - connectors_config = get_connectors_config_for_stage(transfer_config, single_stage_id) - - omni_master_address = args.omni_master_address - omni_master_port = args.omni_master_port - - # Perform handshake with orchestrator to 
get dynamically allocated endpoints - with zmq.Context() as zmq_ctx: - handshake_endpoint = get_engine_client_zmq_addr( - local_only=False, host=omni_master_address, port=omni_master_port - ) - - with make_zmq_socket(zmq_ctx, handshake_endpoint, zmq.REQ, bind=False, linger=5000) as handshake_socket: - # TODO(wuhang): Define protocol in python dataclass. - handshake_msg = {"type": "handshake", "stage_id": single_stage_id} - handshake_socket.send(msgspec.msgpack.encode(handshake_msg)) - - # Wait for response with timeout - if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): - raise RuntimeError( - f"Handshake timeout ({HANDSHAKE_TIMEOUT_MINS} minutes) for stage-{single_stage_id} " - f"at {handshake_endpoint}" - ) - - try: - response = msgspec.msgpack.decode(handshake_socket.recv()) - except msgspec.DecodeError as exc: - raise RuntimeError( - f"Handshake decode failed for stage-{single_stage_id} at {handshake_endpoint}: {exc}" - ) from exc - except Exception as exc: # pragma: no cover - unexpected decode errors - raise RuntimeError( - f"Unexpected error decoding handshake for stage-{single_stage_id} at {handshake_endpoint}: {exc}" - ) from exc - - if not response["ok"]: - error_msg = response["error"] - raise RuntimeError(f"Handshake failed for stage-{single_stage_id}: {error_msg}") - - in_endpoint, out_endpoint = response["in_endpoint"], response["out_endpoint"] - - logger.info( - f"[Headless] Stage-{single_stage_id} received endpoints via handshake: " - f"in={in_endpoint}, out={out_endpoint}" - ) - - shutdown_requested = False - - def signal_handler(signum, frame): - nonlocal shutdown_requested - if shutdown_requested: - return - shutdown_requested = True - raise SystemExit - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - - stage = OmniStage(stage_config, stage_init_timeout=args.stage_init_timeout) - stage.attach_queues(in_endpoint, out_endpoint) - - # Inject YAML-resolved connector config into 
omni_kv_config for in-engine usage. - try: - omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage(transfer_config, single_stage_id) - if omni_conn_cfg: - inject_omni_kv_config(stage, omni_conn_cfg, omni_from, omni_to) # type: ignore - except Exception as e: - logger.debug("[Headless] Failed to inject omni connector config into stage-%s: %s", single_stage_id, e) - - old_env = os.environ.get("VLLM_LOGGING_PREFIX") - os.environ["VLLM_LOGGING_PREFIX"] = f"[Stage-{single_stage_id}] {'' if old_env is None else old_env}" - try: - stage.init_stage_worker( - model, - is_async=True, - shm_threshold_bytes=int(args.shm_threshold_bytes), - batch_timeout=int(args.batch_timeout), - connectors_config=connectors_config, - worker_backend="multi_process", - ignore_runtime_config=True, - ) - if stage._proc is not None: - stage._proc.join() - finally: - stage.stop_stage_worker() + """Run a single stage in headless mode. + + .. deprecated:: 0.x.x + Headless mode is deprecated and will be removed in a future version. + It is only compatible with the old OmniStage-based runtime. + The current AsyncOmniEngine-based runtime does not support headless mode. + + Raises: + RuntimeError: Always raises an error indicating headless mode is deprecated. + """ + raise RuntimeError( + "Headless mode is deprecated and not supported in the current runtime. " + "Please use the standard orchestrator mode (without --headless flag). " + "If you need distributed deployment, consider using Ray backend or " + "other distributed serving solutions." 
+ ) def cmd_init() -> list[CLISubcommand]: diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 488b986d8e..a3bfe98ce2 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -1,1004 +1,42 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json -import multiprocessing as mp -import os -import threading +from __future__ import annotations + +import copy import time import uuid -import weakref -from collections.abc import Callable, Generator, Sequence -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Literal, TypeVar, overload +from collections.abc import Callable, Generator, Iterable, Sequence +from typing import TYPE_CHECKING, Literal, overload -import huggingface_hub -import msgspec.msgpack -import torch -import zmq from tqdm.auto import tqdm -from vllm import SamplingParams from vllm.logger import init_logger -from vllm.utils.network_utils import make_zmq_socket -from vllm.v1.utils import get_engine_client_zmq_addr +from vllm.sampling_params import RequestOutputKind -from vllm_omni.config.stage_config import StageConfigFactory -from vllm_omni.config.yaml_util import create_config -from vllm_omni.distributed.omni_connectors import ( - get_stage_connector_config, - initialize_orchestrator_connectors, -) -from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector -from vllm_omni.distributed.omni_connectors.utils.initialization import ( - resolve_omni_kv_config_for_stage, -) -from vllm_omni.distributed.ray_utils.utils import ( - create_placement_group, - get_ray_queue_class, - try_close_ray, -) -from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker -from vllm_omni.entrypoints.omni_stage import OmniStage -from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType -from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load -from 
vllm_omni.entrypoints.utils import ( - filter_dataclass_kwargs, - get_final_stage_id_for_e2e, - inject_omni_kv_config, - load_and_resolve_stage_configs, -) -from vllm_omni.entrypoints.zmq_utils import ZmqQueue -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams -from vllm_omni.lora.request import LoRARequest -from vllm_omni.metrics import OrchestratorAggregator, StageRequestStats -from vllm_omni.model_executor.model_loader.weight_utils import ( - download_weights_from_hf_specific, -) +from vllm_omni.entrypoints.client_request_state import ClientRequestState +from vllm_omni.entrypoints.omni_base import OmniBase +from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput -_R = TypeVar("_R") +if TYPE_CHECKING: + from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams logger = init_logger(__name__) -def _weak_close_cleanup( - stage_list, - stage_in_queues, - stage_out_queues, - ray_pg, - zmq_ctx=None, - handshake_stop: threading.Event | None = None, - zmq_handshake_socket: zmq.Socket | None = None, - handshake_thread: threading.Thread | None = None, -): - """Weak reference cleanup function for OmniBase instances.""" - if stage_list: - for q in stage_in_queues: - try: - q.put_nowait(SHUTDOWN_TASK) - except Exception as e: - logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for q in stage_out_queues: - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for stage in stage_list: - try: - stage.stop_stage_worker() - except Exception as e: - logger.warning(f"Failed to stop stage worker: {e}") - try_close_ray(ray_pg) - - # Gracefully shutdown handshake server thread - if handshake_stop is not None: - handshake_stop.set() - if handshake_thread is not None: - handshake_thread.join(timeout=2.0) - if 
handshake_thread.is_alive(): - logger.warning("Handshake server thread did not terminate gracefully within timeout") - - # Close ZMQ resources after thread has exited - if zmq_handshake_socket is not None: - zmq_handshake_socket.close(0) - if zmq_ctx is not None: - zmq_ctx.term() - - -def _dummy_snapshot_download(model_id): - return model_id - - -def omni_snapshot_download(model_id) -> str: - # If it's already a local path, just return it - if os.path.exists(model_id): - return model_id - # TODO: this is just a workaround for quickly use modelscope, we should support - # modelscope in weight loading feature instead of using `snapshot_download` - if os.environ.get("VLLM_USE_MODELSCOPE", False): - from modelscope.hub.snapshot_download import snapshot_download - - return snapshot_download(model_id) - # For other cases (Hugging Face), perform a real download to ensure all - # necessary files (including *.pt for audio/diffusion) are available locally - # before stage workers are spawned. This prevents initialization timeouts. - # Return the original model_id so that model_config.model preserves - # HuggingFace semantics (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice") - # instead of the resolved cache path. - try: - download_weights_from_hf_specific( - model_name_or_path=model_id, - cache_dir=None, - allow_patterns=["*"], - require_all=True, - ) - except huggingface_hub.errors.RepositoryNotFoundError: - logger.warning(f"Repository not found for '{model_id}'.") - return model_id - - -class OmniBase: - """Base class for serving Omni models. - - Args: - model: Model name or path to load. - **kwargs: Arbitrary keyword arguments. - - stage_configs_path: Optional path to YAML file containing stage - configurations. If None, configurations are loaded from the model. - - log_stats: Whether to enable statistics logging - be written to files with stage-specific suffixes. - - stage_init_timeout: Per-stage init watchdog (seconds). 
Measured from - when the previous stage finished (possibly a prior Omni run with GPU - reuse/overlap) to when the current stage starts to initialize. - - shm_threshold_bytes: Threshold in bytes for using shared memory - for IPC. Objects larger than this threshold will use shared memory. - - worker_backend: Backend for worker processes. Default is "multi_process". - - ray_address: Address of Ray cluster for Ray backend, if using Ray backend. - - batch_timeout: Timeout in seconds for batching requests within a stage - - init_timeout: Timeout in seconds for waiting for all stages to initialize - - Additional keyword arguments passed to stage engines. - """ - - def __init__(self, model: str, **kwargs: Any) -> None: - model = omni_snapshot_download(model) - kwargs["model"] = model - - # Stage management attributes - self.stage_list: list[OmniStage] = [] - self._stage_in_queues: list[Any] = [] - self._stage_out_queues: list[Any] = [] - self._stages_ready: set[int] = set() - self._ray_pg = None - self._queue_cls = None - self._ctx = None - self._zmq_ctx: zmq.Context | None = None - self._zmq_master_address: str | None = None - self._zmq_master_port: int | None = None - self._zmq_handshake_socket: zmq.Socket | None = None - self._handshake_thread: threading.Thread | None = None - self._handshake_stop: threading.Event | None = None - self._handshake_endpoints: dict[int, tuple[str, str]] = {} - self._handshake_seen: set[int] = set() # Track which stage IDs have completed ZMQ handshake - self._single_stage_id: int | None = None # Optional: deploy only a specific stage ID - - # Sleep mode tracking - self._is_sleeping: bool = False - - # RPC results storage: {stage_id: {rpc_id: result}} - # Used by collective_rpc to retrieve results collected from the output queue - self._rpc_results: dict[int, dict[str, dict[str, Any]]] = {} - - # Initialize stages - each stage will create appropriate instance based on stage_type - # Stage workers will automatically create OmniLLM or 
OmniDiffusion instances - # based on stage_type in YAML config (handled in omni_stage.py) - logger.info(f"Initializing stages for model: {model}") - self._initialize_stages(model, kwargs) - - def _get_default_cache_config(self, cache_backend: str | None) -> dict[str, Any] | None: - if cache_backend == "cache_dit": - return { - "Fn_compute_blocks": 1, - "Bn_compute_blocks": 0, - "max_warmup_steps": 4, - "residual_diff_threshold": 0.24, - "max_continuous_cached_steps": 3, - "enable_taylorseer": False, - "taylorseer_order": 1, - "scm_steps_mask_policy": None, - "scm_steps_policy": "dynamic", - } - if cache_backend == "tea_cache": - return { - "rel_l1_thresh": 0.2, - } - return None - - def _normalize_cache_config(self, cache_backend: str | None, cache_config: Any | None) -> Any | None: - if isinstance(cache_config, str): - try: - cache_config = json.loads(cache_config) - except json.JSONDecodeError: - logger.warning("Invalid cache_config JSON, using defaults.") - cache_config = None - if cache_config is None and cache_backend not in (None, "", "none"): - cache_config = self._get_default_cache_config(cache_backend) - return cache_config - - def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> list[dict[str, Any]]: - """Create default diffusion stage configuration. - - Uses StageConfigFactory for typed configuration creation while - maintaining backward compatibility with the legacy format. - - Args: - kwargs: Engine arguments from CLI/API. - - Returns: - List containing a single OmegaConf config for the diffusion stage. 
- """ - # Normalize dtype - if "dtype" in kwargs and not isinstance(kwargs["dtype"], str): - if not isinstance(kwargs["dtype"], torch.dtype): - raise TypeError(f"Provided dtype must be a string or torch.dtype, got {type(kwargs['dtype']).__name__}") - kwargs["dtype"] = str(kwargs["dtype"]).removeprefix("torch.") - - # Normalize cache config before passing to factory - cache_backend = kwargs.get("cache_backend", "none") - cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None)) - - # Update kwargs with normalized values - kwargs_copy = dict(kwargs) - kwargs_copy["cache_backend"] = cache_backend - kwargs_copy["cache_config"] = cache_config - - # Use the factory to create default diffusion config - return StageConfigFactory.create_default_diffusion(kwargs_copy) - - def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]: - """Resolve stage configs and inject defaults shared by orchestrator/headless.""" - # TODO(wuhang): - # Remove kwargs as parameters in the future. - # Use dataclass directly for engine args. - - stage_configs_path = kwargs.get("stage_configs_path", None) - - # TTS-specific CLI overrides - self.tts_max_instructions_length: int | None = kwargs.get("tts_max_instructions_length", None) - - # Load stage configurations from YAML - config_path, stage_configs = load_and_resolve_stage_configs( - model, - stage_configs_path, - kwargs, - default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs), - ) - - # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. 
- for cfg in stage_configs: - try: - if getattr(cfg, "stage_type", None) != "diffusion": - continue - if not hasattr(cfg, "engine_args") or cfg.engine_args is None: - cfg.engine_args = create_config({}) - if kwargs.get("lora_path") is not None: - if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None: - cfg.engine_args.lora_path = kwargs["lora_path"] - lora_scale = kwargs.get("lora_scale") - if lora_scale is None: - # Backwards compatibility for older callers. - lora_scale = kwargs.get("static_lora_scale") - if lora_scale is not None: - if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: - cfg.engine_args.lora_scale = lora_scale - quantization_config = kwargs.get("quantization_config") - if quantization_config is not None: - if ( - not hasattr(cfg.engine_args, "quantization_config") - or cfg.engine_args.quantization_config is None - ): - cfg.engine_args.quantization_config = quantization_config - except Exception as e: - logger.warning("Failed to inject LoRA config for stage: %s", e) - - return config_path, stage_configs - - def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: - """Initialize stage list management.""" - self._inline_diffusion = False - self._inline_engine = None - - stage_init_timeout = kwargs.get("stage_init_timeout", 20) - shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) - init_timeout = kwargs.get("init_timeout", 300) - worker_backend = kwargs.get("worker_backend", "multi_process") - ray_address = kwargs.get("ray_address", None) - batch_timeout = kwargs.get("batch_timeout", 10) - log_stats = kwargs.get("log_stats", False) - self._single_stage_id = kwargs.get("stage_id", None) - self._zmq_master_address = kwargs.get("omni_master_address", None) - if self._zmq_master_address is None: - self._zmq_master_address = "127.0.0.1" - logger.info("No omni_master_address provided, defaulting to localhost (127.0.0.1)") - self._zmq_master_port = 
kwargs.get("omni_master_port", None) - - # Resolve stage configs shared by orchestrator/headless paths. - self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs) - - # Initialize connectors - self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( - self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes - ) - - # Initialize stats paths - self.log_stats: bool = bool(log_stats) - - self.worker_backend = worker_backend - self.ray_address = ray_address - self.batch_timeout = batch_timeout - # async chunk remains the same for each stage - self.async_chunk = self._is_async_chunk_enable(self.stage_configs) - - # Build OmniStage instances in parallel, preserve original order - def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: - idx, cfg = idx_cfg - return idx, OmniStage(cfg, stage_init_timeout=stage_init_timeout) - - with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor: - futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)] - results: list[tuple[int, OmniStage]] = [] - for fut in as_completed(futures): - results.append(fut.result()) - results.sort(key=lambda x: x[0]) - self.stage_list = [st for _, st in results] - self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] - self.output_modalities = [st.final_output_type for st in self.stage_list] - logger.info(f"[{self._name}] Loaded {len(self.stage_list)} stages") - - # Phase 1 optimization: for a single diffusion stage in async mode, - # run the engine directly in the orchestrator process to eliminate - # the stage worker subprocess and its IPC serialization overhead. 
- if len(self.stage_list) == 1 and self.stage_list[0].stage_type == "diffusion" and self.is_async: - self._init_inline_diffusion_engine(model, self.stage_configs[0], kwargs) - return - - if self.worker_backend == "ray": - self._queue_cls = get_ray_queue_class() - else: - self._ctx = mp.get_context("spawn") - self._queue_cls = lambda: self._ctx.Queue(maxsize=0) - - self._stage_init_timeout = max(0, int(stage_init_timeout)) - self._shm_threshold_bytes = max(0, int(shm_threshold_bytes)) - self._start_stages(model) - # Wait for all stages to report readiness before seeding - self._wait_for_stages_ready(timeout=init_timeout) - # Set up RPC result checkers so that collective_rpc works - self._setup_rpc_result_checkers() - - def _init_inline_diffusion_engine( - self, - model: str, - stage_config: Any, - kwargs: dict[str, Any], - ) -> None: - """Initialize diffusion engine directly in the orchestrator process. - - For single-stage diffusion pipelines, this eliminates the stage worker - subprocess and the associated Hop3 IPC serialization overhead. - GPU workers for tensor parallelism are still spawned by the - DiffusionExecutor as separate processes. 
- """ - from vllm_omni.diffusion.data import OmniDiffusionConfig - from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion - from vllm_omni.entrypoints.stage_utils import ( - _to_dict, - load_func_from_config, - set_stage_devices, - ) - - stage_id = stage_config.stage_id - engine_args = _to_dict(stage_config.engine_args) - runtime_cfg = _to_dict(getattr(stage_config, "runtime", {})) - - if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - try: - from vllm_omni.platforms import current_omni_platform - - device_type = current_omni_platform.device_type - set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) - except Exception as e: - logger.warning("Device setup for inline diffusion failed: %s", e) - - engine_args = filter_dataclass_kwargs(OmniDiffusionConfig, engine_args) - engine_args.pop("model_stage", None) - engine_args.pop("model", None) - - cfg_kv_collect_func = load_func_from_config(getattr(stage_config, "cfg_kv_collect_func", None)) - - self._inline_engine = OmniDiffusion( - model=model, - stage_id=stage_id, - engine_input_source=getattr(stage_config, "engine_input_source", []), - cfg_kv_collect_func=cfg_kv_collect_func, - **engine_args, - ) - self._inline_diffusion = True - - # These attributes are normally set by AsyncOmni._wait_for_stages_ready - # but we skip that for inline mode. Set them to None since there is no - # LLM stage to provide them. 
- self.input_processor = None - self.io_processor = None - self.model_config = None - - logger.info( - "[%s] Inline diffusion mode active – stage worker subprocess bypassed", - self._name, - ) - - def _is_async_chunk_enable(self, stage_args: list) -> bool: - """get async chunk flag""" - engine_args = getattr(stage_args[0], "engine_args", None) - return bool(getattr(engine_args, "async_chunk", False)) - - def _start_stages(self, model: str) -> None: - """Start all stage processes.""" - if self.worker_backend == "ray": - # Initialize Ray Cluster - self._ray_pg = create_placement_group( - number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" - ) - else: - # Initialize ZMQ context - if self._zmq_ctx is None: - self._zmq_ctx = zmq.Context() - - # Allocate endpoints for each stage - total_stages = len(self.stage_configs) - self._handshake_endpoints = {} - - # If --stage-id is not set, use local_only mode - local_only = self._single_stage_id is None - - for sid in range(total_stages): - in_endpoint = get_engine_client_zmq_addr(local_only=local_only, host=self._zmq_master_address) - out_endpoint = get_engine_client_zmq_addr(local_only=local_only, host=self._zmq_master_address) - self._handshake_endpoints[sid] = (in_endpoint, out_endpoint) - logger.debug( - f"[{self._name}] Allocated endpoints for stage-{sid}: in={in_endpoint}, out={out_endpoint}" - ) - - # Start handshake server - self.start_handshake_server() - - for stage_id, stage in enumerate[OmniStage](self.stage_list): - if self.worker_backend == "ray": - in_q = self._queue_cls() - out_q = self._queue_cls() - else: - in_endpoint, out_endpoint = self._handshake_endpoints[stage_id] - in_q = ZmqQueue(self._zmq_ctx, zmq.PUSH, bind=in_endpoint) - out_q = ZmqQueue(self._zmq_ctx, zmq.PULL, bind=out_endpoint) - - self._stage_in_queues.append(in_q) - self._stage_out_queues.append(out_q) - stage.attach_queues(in_q, out_q) - - stage_connectors_config = get_stage_connector_config( - 
self.omni_transfer_config, - stage_id, - ) - - # Inject YAML-resolved connector config into omni_kv_config for - # in-engine usage (GPU model runner reads model_config.omni_kv_config). - try: - omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage( - self.omni_transfer_config, stage_id - ) - if omni_conn_cfg: - inject_omni_kv_config(stage, omni_conn_cfg, omni_from, omni_to) # type: ignore - - except Exception as e: - logger.debug("[Omni] Failed to inject omni connector config into stage-%s: %s", stage_id, e) - - if self._single_stage_id is not None and stage_id != int(self._single_stage_id): - logger.info( - f"[{self._name}] Skipping initialization of stage-{stage_id} worker due to single_stage_id setting" - ) - continue - - stage.init_stage_worker( - model, - is_async=self.is_async, - shm_threshold_bytes=self._shm_threshold_bytes, - ctx=self._ctx if self.worker_backend != "ray" else None, - batch_timeout=self.batch_timeout, - connectors_config=stage_connectors_config, - worker_backend=self.worker_backend, - ray_placement_group=self._ray_pg, - ignore_runtime_config=True if self._single_stage_id is not None else False, - ) - - logger.debug(f"[{self._name}] Stage-{stage_id} process started") - - def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str, Any]) -> None: - self._stages_ready.add(stage_id) - logger.info(f"[{self._name}] Stage-{stage_id} reported ready") - - def _wait_for_stages_ready(self, timeout: int = 120) -> None: - """Wait for all stages to report readiness with optimized polling.""" - if self._single_stage_id is not None and self.worker_backend != "ray": - timeout = self._wait_for_handshakes(timeout) - - num_stages = len(self.stage_list) - deadline = time.time() + max(0, int(timeout)) - - logger.info(f"[{self._name}] Waiting for {num_stages} stages to initialize (timeout: {timeout}s)") - - while len(self._stages_ready) < num_stages and time.time() < deadline: - progressed = False - for stage_id, stage in 
enumerate(self.stage_list): - if stage_id in self._stages_ready: - continue - - # Check if the stage has reported status - if result := stage.try_collect(): - progressed = True - if result.get("type") == "stage_ready": - self._process_stage_ready(stage, stage_id, result) - - if not progressed: - time.sleep(0.05) - - # Handle Final State - if len(self._stages_ready) == num_stages: - logger.info(f"[{self._name}] All stages initialized successfully") - return - - # Handle Timeout/Failure - not_ready = sorted(set(range(num_stages)) - set(self._stages_ready)) - logger.warning( - f"[{self._name}] Initialization timeout: {len(self._stages_ready)}/{num_stages} " - f"stages ready. Missing stages: {not_ready}" - ) - - suggestions = [ - f"Ignore this warning if the model weight download / load from disk time is longer than {timeout}s.", - "Verify GPU/device assignment in config (runtime.devices) is correct.", - "Check GPU/host memory availability; reduce model or batch size if needed.", - "Check model weights path and network reachability (if loading remotely).", - "Increase initialization wait time (stage_init_timeout or call-site timeout).", - ] - - formatted_suggestions = "\n".join(f" {i + 1}) {msg}" for i, msg in enumerate(suggestions)) - - logger.warning(f"[{self._name}] Stage initialization timeout. 
Troubleshooting Steps:\n{formatted_suggestions}") - - def _is_profiler_enabled(self, stage_id: int) -> bool: - """Check if profiler config is set for a given stage.""" - stage = self.stage_list[stage_id] - # For diffusion stages, profiling is controlled by VLLM_TORCH_PROFILER_DIR env var - if stage.stage_type == "diffusion": - return True - # For LLM stages, check if profiler_config is set in engine_args - engine_args = getattr(stage.stage_config, "engine_args", None) - if engine_args is None: - return False - profiler_config = getattr(engine_args, "profiler_config", None) - if profiler_config is None: - return False - profiler = getattr(profiler_config, "profiler", None) - return profiler is not None - - def _setup_rpc_result_checkers(self) -> None: - """Set up RPC result checkers for all stages. - - Each checker reads from the shared ``_rpc_results`` dict so that - ``OmniStage.collective_rpc`` can retrieve results that were - collected by the orchestrator (or, in the sync path, by the - stage-level checker that drains the output queue). - - Uses a weak reference to ``self`` to avoid a circular reference - (OmniBase → stage_list → OmniStage → closure → OmniBase) that - would prevent the instance from being freed by reference counting. - """ - weak_self = weakref.ref(self) - for stage in self.stage_list: - sid = stage.stage_id - - def make_rpc_checker(stage_id: int): - def rpc_checker(rpc_id: str) -> dict[str, Any] | None: - _self = weak_self() - if _self is None: - return None - # First check the shared dict - if stage_id in _self._rpc_results and rpc_id in _self._rpc_results[stage_id]: - return _self._rpc_results[stage_id].pop(rpc_id) - # In the sync path there is no background output handler, - # so drain the output queue ourselves and stash any - # non-RPC results back. 
- out_q = _self._stage_out_queues[stage_id] if stage_id < len(_self._stage_out_queues) else None - if out_q is not None: - import queue as _queue - - try: - while True: - item = out_q.get_nowait() - if isinstance(item, dict) and item.get("type") == "collective_rpc_result": - item_rpc_id = item.get("rpc_id") - if item_rpc_id == rpc_id: - return item - # Stash for another caller - if stage_id not in _self._rpc_results: - _self._rpc_results[stage_id] = {} - _self._rpc_results[stage_id][item_rpc_id] = item - else: - # Non-RPC item — put it back - out_q.put(item) - break - except _queue.Empty: - pass - return None - - return rpc_checker - - stage._rpc_result_checker = make_rpc_checker(sid) - - def collective_rpc( - self, - method: str | Callable[..., _R], - timeout: float | None = None, - args: tuple = (), - kwargs: dict[str, Any] | None = None, - ) -> list[_R]: - """Execute an RPC call on all stage workers. - - Args: - method: Name of the worker method to execute, or a callable. - timeout: Maximum time in seconds to wait for execution. - args: Positional arguments to pass to the worker method. - kwargs: Keyword arguments to pass to the worker method. - - Returns: - A list containing the results from each stage. - """ - results: list[_R] = [] - for stage in self.stage_list: - results.append( - stage.collective_rpc( - method=method, - timeout=timeout, - args=args, - kwargs=kwargs, - ) - ) - return results - - def sleep(self, level: int = 1) -> None: - """Put the engine to sleep to free up resources. - - Args: - level: Sleep level (1 = light sleep, higher = deeper sleep). - """ - self._is_sleeping = True - self.collective_rpc(method="sleep", args=(level,)) - - def wake_up(self, tags: list[str] | None = None) -> None: - """Wake the engine up from sleep. - - Args: - tags: Optional list of tags to selectively wake components. 
- """ - self._is_sleeping = False - self.collective_rpc(method="wake_up", args=(tags,)) - - def is_sleeping(self) -> bool: - """Check whether the engine is sleeping.""" - return getattr(self, "_is_sleeping", False) - - def add_lora(self, lora_request: LoRARequest) -> bool: - """Load a new LoRA adapter into the engine for future requests.""" - result = self.collective_rpc(method="add_lora", args=(lora_request,)) - return result[0][0] - - def start_profile(self, stages: list[int] | None = None) -> None: - """Start profiling for specified stages. - - Sends start_profile command to stage workers. Profiling must be enabled - via VLLM_TORCH_PROFILER_DIR environment variable. - - Args: - stages: List of stage IDs to start profiling. If None, starts - profiling for all stages that have profiling enabled. - - Example: - >>> # Profile all stages - >>> omni.start_profile() - >>> outputs = omni.generate(prompts, sampling_params) - >>> omni.stop_profile() - - >>> # Profile only stage 0 and 2 - >>> omni.start_profile(stages=[0, 2]) - """ - if stages is None: - stages = list(range(len(self.stage_list))) - - for stage_id in stages: - if stage_id < len(self.stage_list): - if not self._is_profiler_enabled(stage_id): - logger.info( - "[%s] Skipping start_profile for stage-%s: profiler config not set", - self._name, - stage_id, - ) - continue - try: - self.stage_list[stage_id].submit({"type": OmniStageTaskType.PROFILER_START}) - logger.info("[%s] Sent start_profile to stage-%s", self._name, stage_id) - except Exception as e: - logger.warning( - "[%s] Failed to send start_profile to stage-%s: %s", - self._name, - stage_id, - e, - ) - - def stop_profile(self, stages: list[int] | None = None) -> dict: - """ - Synchronously stop profiling for specified stages and collect - the file paths for traces and tables. 
- """ - if stages is None: - stages = list(range(len(self.stage_list))) - - all_results = {"traces": [], "tables": []} - - for stage_id in stages: - if stage_id < len(self.stage_list): - if not self._is_profiler_enabled(stage_id): - logger.info( - "[%s] Skipping stop_profile for stage-%s: profiler config not set", - self._name, - stage_id, - ) - continue - stage = self.stage_list[stage_id] - - # Check if the stage object has our new bridge method - if hasattr(stage, "stop_profile"): - logger.info("[%s] Requesting profile data collection from stage-%s", self._name, stage_id) - - # This is the blocking call that triggers the RPC chain - stage_data = stage.stop_profile() - - if isinstance(stage_data, dict): - # FIX: Handle both single key and list key formats - traces = stage_data.get("trace") or stage_data.get("traces") - tables = stage_data.get("table") or stage_data.get("tables") - - # Debug logging - logger.debug(f"[{self._name}] Stage-{stage_id} returned: {stage_data.keys()}") - if traces: - logger.debug(f"[{self._name}] Stage-{stage_id} traces type: {type(traces)}") - if tables: - logger.debug(f"[{self._name}] Stage-{stage_id} tables type: {type(tables)}") - - # Handle single strings - if traces: - if isinstance(traces, str): - all_results["traces"].append(traces) - elif isinstance(traces, list): - all_results["traces"].extend(traces) - - # Handle single strings - if tables: - if isinstance(tables, str): - all_results["tables"].append(tables) - elif isinstance(tables, list): - all_results["tables"].extend(tables) - else: - logger.warning(f"[{self._name}] Stage-{stage_id} returned no table data") - else: - logger.warning(f"[{self._name}] Stage-{stage_id} returned non-dict data: {type(stage_data)}") - else: - # Fallback for non-diffusion stages - logger.warning( - "[%s] Stage-%s does not support synchronous stop_profile. 
Falling back to async.", - self._name, - stage_id, - ) - stage.submit({"type": OmniStageTaskType.PROFILER_STOP}) - - # Final debug output - logger.info( - f"[{self._name}] Collected {len(all_results['traces'])} trace(s) and {len(all_results['tables'])} table(s)" - ) - - return all_results - - def close(self) -> None: - """Close all stage processes and clean up resources.""" - if hasattr(self, "_weak_finalizer"): - self._weak_finalizer() - - def _process_handshake_message(self, msg: Any) -> dict[str, Any]: - """Process incoming handshake message and generate response. - - Args: - msg: Decoded message from client - - Returns: - Response dictionary with ok status and either endpoints or error - """ - if not isinstance(msg, dict) or msg.get("type") != "handshake": - return {"ok": False, "error": "invalid handshake payload"} - - try: - stage_id = int(msg.get("stage_id")) - except (TypeError, ValueError) as e: - return {"ok": False, "error": f"invalid stage_id: {e}"} - - endpoints = self._handshake_endpoints.get(stage_id) - if endpoints is None: - return {"ok": False, "error": f"unknown stage_id: {stage_id}"} - - # Mark stage as seen and prepare success response - self._handshake_seen.add(stage_id) - in_endpoint, out_endpoint = endpoints - - logger.info( - "[%s] Handshake received from stage-%s", - self._name, - stage_id, - ) - - return { - "ok": True, - "in_endpoint": in_endpoint, - "out_endpoint": out_endpoint, - } - - def _run_handshake_server_loop(self) -> None: - """Main loop for handshake server - polls for messages and responds.""" - poller = zmq.Poller() - poller.register(self._zmq_handshake_socket, zmq.POLLIN) - - try: - while not self._handshake_stop.is_set(): - events = poller.poll(1000) - has_message = any(sock == self._zmq_handshake_socket and event == zmq.POLLIN for sock, event in events) - if not has_message: - continue - - msg = msgspec.msgpack.decode(self._zmq_handshake_socket.recv()) - response = 
msgspec.msgpack.encode(self._process_handshake_message(msg)) - self._zmq_handshake_socket.send(response) - finally: - poller.unregister(self._zmq_handshake_socket) - - def start_handshake_server(self) -> None: - """Start the ZMQ handshake server. - - The handshake server allows distributed stages to discover their - queue endpoints by querying the orchestrator with their stage_id. - Skips starting if the server is already running or ZMQ is not initialized. - """ - # Skip if already running or ZMQ not initialized - if self._handshake_thread is not None or self._zmq_ctx is None: - return - - # Skip if master address/port not configured - if not self._zmq_master_address or self._zmq_master_port is None: - return - - # Create server endpoint and socket - endpoint = get_engine_client_zmq_addr( - local_only=False, host=self._zmq_master_address, port=int(self._zmq_master_port) - ) - - self._handshake_stop = threading.Event() - self._zmq_handshake_socket = make_zmq_socket(self._zmq_ctx, endpoint, zmq.REP, bind=True, linger=5000) - - # Start server thread - self._handshake_thread = threading.Thread( - target=self._run_handshake_server_loop, daemon=True, name="zmq-handshake-server" - ) - self._handshake_thread.start() - - def _wait_for_handshakes(self, timeout: int = 120) -> int: - """Wait for handshakes from all expected stages. - - Args: - timeout: Timeout in seconds for waiting for handshakes. Default is 120s. - - Returns: - Remaining timeout in seconds after waiting for handshakes. - """ - total_stages = len(self.stage_configs) - expected = set(range(total_stages)) - {int(self._single_stage_id)} - if not expected: - return timeout - - deadline = time.time() + max(0, int(timeout)) - logger.info(f"[{self._name}] Waiting for handshakes from stages: {expected} (timeout: {timeout}s)") - - # NOTE: _handshake_seen may be updated from the handshake server thread. 
- # It is intentionally used here without additional locking because: - # - _handshake_seen only ever grows (stages are added but never removed), and - # - we only check membership and set inclusion relative to `expected`. - # Under these monotonic semantics and the CPython GIL, concurrent reads/writes - # are safe for this usage and cannot violate correctness: we may observe a - # slightly stale view, but the loop condition remains valid and eventually - # becomes true once all expected stages have handshaked or the timeout elapses. - while not expected.issubset(self._handshake_seen) and time.time() < deadline: - time.sleep(1.0) - - remaining_timeout = max(0, int(deadline - time.time())) - - if not expected.issubset(self._handshake_seen): - missing = sorted(expected - self._handshake_seen) - logger.warning( - f"[{self._name}] Handshake timeout: {len(self._handshake_seen)}/{len(expected)} " - f"stages completed handshake. Missing stages: {missing}" - ) - - return remaining_timeout - - @property - def _name(self) -> str: - return "OmniBase" - - @property - def is_async(self) -> bool: - return False - - class Omni(OmniBase): - """Unified entrypoint for both LLM and Diffusion models for better usability. - - Args: - model: Model name or path to load. - **kwargs: Arbitrary keyword arguments. - - stage_configs_path: Optional path to YAML file containing stage - configurations. If None, configurations are loaded from the model. - - log_stats: Whether to enable statistics logging - be written to files with stage-specific suffixes. - - stage_init_timeout: Per-stage init watchdog (seconds). Measured from - when the previous stage finished (possibly a prior Omni run with GPU - reuse/overlap) to when the current stage starts to initialize. - - shm_threshold_bytes: Threshold in bytes for using shared memory - for IPC. Objects larger than this threshold will use shared memory. - - worker_backend: Backend for worker processes. Default is "multi_process". 
- - ray_address: Address of Ray cluster for Ray backend, if using Ray backend. - - batch_timeout: Timeout in seconds for batching requests within a stage - - init_timeout: Timeout in seconds for waiting for all stages to initialize - - Additional keyword arguments passed to stage engines. - - Example: - >>> omni = Omni(model="Qwen/Qwen2.5-Omni-7B") - >>> outputs = omni.generate(prompts="Hello, world!", sampling_params_list=[SamplingParams()]) - >>> print(outputs) - """ + """Synchronous entrypoint for offline generation.""" - def __init__(self, model: str, **kwargs: Any) -> None: - super().__init__(model, **kwargs) - - # Register weak reference cleanup (called on garbage collection) - self._weak_finalizer = weakref.finalize( - self, - _weak_close_cleanup, - self.stage_list, - self._stage_in_queues, - self._stage_out_queues, - self._ray_pg, - self._zmq_ctx, - self._handshake_stop, - self._zmq_handshake_socket, - self._handshake_thread, - ) + def _set_final_only_for_llm_stages( + self, + sampling_params_list: Sequence[OmniSamplingParams], + ) -> list[OmniSamplingParams]: + """Return per-stage params with LLM stages forced to FINAL_ONLY.""" + effective_params: list[OmniSamplingParams] = [] + for stage_id, params in enumerate(sampling_params_list): + sp = copy.deepcopy(params) + stage_meta = self.engine.get_stage_metadata(stage_id) + if stage_meta.get("stage_type") != "diffusion" and hasattr(sp, "output_kind"): + sp.output_kind = RequestOutputKind.FINAL_ONLY + effective_params.append(sp) + return effective_params @overload def generate( @@ -1007,6 +45,7 @@ def generate( sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, *, py_generator: Literal[True], + use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None]: ... 
@overload @@ -1016,6 +55,7 @@ def generate( sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, *, py_generator: Literal[False] = False, + use_tqdm: bool | Callable[..., tqdm] = True, ) -> list[OmniRequestOutput]: ... def generate( @@ -1026,66 +66,26 @@ def generate( py_generator: bool = False, use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]: - """Generate outputs for the given prompts. - - Orchestrates the multi-stage pipeline based on YAML configuration. - Each stage will use OmniLLM or OmniDiffusion based on stage_type. - - Args: - prompts: Input prompt(s) for generation. - sampling_params_list: Optional list of per-stage parameters. - py_generator: Whether the returned result(s) are wrapped in a generator instead of a list. - use_tqdm: Whether to use tqdm progress bar - - Returns: - List of OmniRequestOutput objects, one for each input prompt. - Each output contains the stage_id, final_output_type, and - the request_output from the final stage. - - Raises: - ValueError: If sampling_params_list is None or has incorrect length. - """ - if sampling_params_list is None: - sampling_params_list = self.default_sampling_params_list - elif not isinstance(sampling_params_list, Sequence): - # TODO: After the recent introduction of BAGEL model (one LLM and one Diffusion), - # expect the text_to_image example code to run when only passing one OmniDiffusionSamplingParams - # This behavior may be confusing, and future PR can improve it. 
- per_stage_params: list[OmniSamplingParams] = [] - for default_stage_sp in self.default_sampling_params_list: - default_sp_type = default_stage_sp.__class__ - if default_sp_type == sampling_params_list.__class__: - per_stage_params.append(sampling_params_list) - else: - per_stage_params.append(default_stage_sp) - sampling_params_list = per_stage_params - + sampling_params_list = self.resolve_sampling_params_list(sampling_params_list) try: if py_generator: - return self._run_generation_with_generator(prompts, sampling_params_list) - else: - outputs = list(self._run_generation(prompts, sampling_params_list, use_tqdm)) - return outputs + return self._run_generation_with_generator(prompts, sampling_params_list, use_tqdm) + return list(self._run_generation(prompts, sampling_params_list, use_tqdm)) except Exception as e: - logger.exception("[Orchestrator] Failed to run generation: %s", e) - # Always close on exception to ensure cleanup + logger.exception("[Omni] Failed to run generation: %s", e) self.close() - raise e + raise def _run_generation_with_generator( self, prompts: OmniPromptType | Sequence[OmniPromptType], sampling_params_list: Sequence[OmniSamplingParams], + use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None]: - """Run generation through all stages in the pipeline and return a generator.""" - gen = self._run_generation(prompts, sampling_params_list) + gen = self._run_generation(prompts, sampling_params_list, use_tqdm) try: yield from gen - except Exception as e: - logger.exception("[Orchestrator] Failed to run generation: %s", e) - raise e finally: - # Cleanup when generator is exhausted or closed self.close() def _run_generation( @@ -1094,365 +94,94 @@ def _run_generation( sampling_params_list: Sequence[OmniSamplingParams], use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None]: - """Run generation through all stages in the pipeline.""" - logger.debug(f"[{self._name}] generate() 
called") - if sampling_params_list is None: - raise ValueError("sampling_params_list is required for pipelined generation") - - if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") - - for i, (stage, sp) in enumerate(zip(self.stage_list, sampling_params_list)): - ExpectedSPType = OmniDiffusionSamplingParams if stage.stage_type == "diffusion" else SamplingParams - if not isinstance(sp, ExpectedSPType): - raise ValueError( - f"Expected sampling parameters with type {ExpectedSPType} in stage {i}, got {sp.__class__}" - ) - - # Normalize prompts to a list for per-request iteration - # str is also Sequence but only test list-like containers here - if isinstance(prompts, str) or not isinstance(prompts, Sequence): - request_prompts: list[OmniPromptType] = [prompts] - else: - request_prompts = list(prompts) - - # Orchestrator keeps stage objects for input derivation - num_stages = len(self.stage_list) - - # Generate globally unique request IDs and map them to original prompts - request_ids = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] - request_id_to_prompt = {rid: p for rid, p in zip(request_ids, request_prompts)} - - # Track per-request start time for end-to-end timing - _req_start_ts: dict[str, float] = {} - _wall_start_ts: float = time.time() - - # CFG companion tracking (prompt expansion + lifecycle management) - cfg = CfgCompanionTracker( - prompt_expand_func=getattr(self.stage_list[0], "prompt_expand_func", None), - stage0_sampling_params=sampling_params_list[0], - ) - expanded_companions = cfg.expand_prompts(request_id_to_prompt) + try: + sampling_params_list = self._set_final_only_for_llm_stages(sampling_params_list) - # Determine the final stage for E2E stats (highest stage_id with final_output=True; fallback to last stage) - final_stage_id_to_prompt: dict[str, int] = {} - for rid, prompt in request_id_to_prompt.items(): - if isinstance(prompt, 
dict): - prompt_modalities = prompt.get("modalities", None) + if isinstance(prompts, str) or not isinstance(prompts, Sequence): + request_prompts: list[OmniPromptType] = [prompts] else: - prompt_modalities = None - final_stage_id_for_e2e = get_final_stage_id_for_e2e( - prompt_modalities, self.output_modalities, self.stage_list - ) - final_stage_id_to_prompt[rid] = final_stage_id_for_e2e - - # Metrics/aggregation helper - metrics = OrchestratorAggregator( - num_stages, - self.log_stats, - _wall_start_ts, - final_stage_id_to_prompt, - ) - - it = request_id_to_prompt.items() - if use_tqdm: - tqdm_func = use_tqdm if callable(use_tqdm) else tqdm - it = tqdm_func(it, desc="Adding requests") - - # Seed stage-0 queue with all requests - logger.debug(f"[{self._name}] Seeding {len(request_prompts)} requests into stage-0") - # Mark first input time for stage-0 - metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() - - for req_id, prompt in request_id_to_prompt.items(): - sp0 = sampling_params_list[0] # type: ignore[index] - task = { - "request_id": req_id, - "engine_inputs": prompt, - "sampling_params": sp0, - } - self.stage_list[0].submit(task) - _req_start_ts[req_id] = time.time() - logger.debug(f"[{self._name}] Enqueued request {req_id} to stage-0") - - # Submit CFG companion requests to stage-0 - if cfg.is_active: - for companion_id, companion_prompt in expanded_companions: - task = { - "request_id": companion_id, - "engine_inputs": companion_prompt, - "sampling_params": cfg.stage0_sampling_params, - } - self.stage_list[0].submit(task) - _req_start_ts[companion_id] = time.time() - logger.debug(f"[{self._name}] Enqueued CFG companion {companion_id} to stage-0") + request_prompts = list(prompts) + + if not request_prompts: + return + + request_ids = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] + req_start_ts: dict[str, float] = {} + wall_start_ts = time.time() + req_final_stage_ids: dict[str, int] = {} + + for req_id, prompt in 
zip(request_ids, request_prompts): + prompt_modalities = prompt.get("modalities", None) if isinstance(prompt, dict) else None + final_stage_id = self._compute_final_stage_id(prompt_modalities) + req_final_stage_ids[req_id] = final_stage_id + + metrics = OrchestratorMetrics( + self.num_stages, + self.log_stats, + wall_start_ts, + final_stage_id, + ) + req_state = ClientRequestState(req_id) + req_state.metrics = metrics + self.request_states[req_id] = req_state + + self.engine.add_request( + request_id=req_id, + prompt=prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id, + ) + submit_ts = time.time() + req_state.metrics.stage_first_ts[0] = submit_ts + req_start_ts[req_id] = submit_ts - pbar = None - if use_tqdm: - tqdm_func = use_tqdm if callable(use_tqdm) else tqdm - pbar = tqdm_func( - total=len(request_prompts), - desc="Processed prompts", - dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} unit/s, output: {0:.2f} unit/s"), - ) - # For each stage, forward results to next stage; collect finals at the end - # We pipeline by continually polling output queues in stage order - remaining_by_stage: list[int] = [len(request_prompts) + cfg.num_companions] + [0] * (num_stages - 1) - completed_requests = 0 - total_requests = len(request_prompts) + active_reqs = set(request_ids) + pbar = None + if use_tqdm: + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func(total=len(request_ids), desc="Processed prompts", dynamic_ncols=True) - logger.debug( - f"[{self._name}] Entering scheduling loop: total_requests={total_requests}, stages={num_stages}", - ) - while completed_requests < total_requests: - made_progress = False - for stage_id, stage in enumerate(self.stage_list): - result = stage.try_collect() - if result is None: - continue + while active_reqs: + msg = self.engine.try_get_output() - made_progress = True - req_id = result.get("request_id") - if "error" in result: - logger.error( - f"[{self._name}] Stage 
{stage_id} error on request {req_id}: {result['error']}", - ) - if cfg.is_companion(req_id) and stage_id == 0: - parent_id, parent_aborted = cfg.on_companion_error(req_id) - if parent_aborted: - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {parent_id} aborted due to " - f"companion failure ({completed_requests}/{total_requests})", - ) + should_continue, req_id, stage_id, req_state = self._handle_output_message(msg) + if should_continue: continue - if result.get("type") == "stage_ready": - # Only happens when stage is initialized slower than expected, - # so we wait for a short time and try again - time.sleep(0.05) + if req_id not in active_reqs: + logger.warning("[Omni] Received output for unknown/finished request_id=%s", req_id) continue - # CFG: companion requests only run through Stage-0 - if cfg.is_companion(req_id) and stage_id == 0: - ready_parent = cfg.on_companion_completed(req_id) - if ready_parent is not None: - success = cfg.forward_parent_with_cfg( - ready_parent, - cfg.pop_pending_parent(ready_parent), - self.stage_list, - self.connectors, - sampling_params_list, - request_id_to_prompt, - final_stage_id_to_prompt, - metrics, - remaining_by_stage, - ) - if not success: - cfg.consume_parent_failure(ready_parent) - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {ready_parent} dropped due to CFG forwarding failure " - f"({completed_requests}/{total_requests})", - ) + if req_state.metrics is None: continue - - engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") - # Mark last output time for this stage whenever we receive outputs - metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) - try: - _m: StageRequestStats = result.get("metrics") - if _m is not None: - # Accumulate generation time - metrics.accumulated_gen_time_ms[req_id][stage_id] += _m.stage_gen_time_ms - - # For diffusion stages, we also accumulate diffusion time - 
metrics.accumulate_diffusion_metrics(stage.stage_type, req_id, engine_outputs) - - metrics.on_stage_metrics(stage_id, req_id, _m, stage.final_output_type) - if pbar: - elapsed = pbar.format_dict["elapsed"] or 1e-6 - # Aggregate total tokens/images across all stages - total_out = sum(metrics.stage_total_tokens) - out_spd = total_out / elapsed - - modality = self.output_modalities[stage_id] - unit = "img" if modality == "image" else "tok" - - # Pre-calculate for cleaner string formatting - if metrics.e2e_count > 0: - avg_lat = metrics.e2e_total_ms / metrics.e2e_count - else: - avg_lat = 0 - - # Align with vLLM's wording "est. speed" using multi-line parentheses - pbar.postfix = ( - f"est. speed stage-{stage_id} {unit}/s: {out_spd:.2f}, avg e2e_lat: {avg_lat:.1f}ms" - ) - except Exception as e: - logger.exception( - f"[{self._name}] Failed to process metrics for stage {stage_id}, req {req_id}: {e}", - ) - logger.debug( - f"[{self._name}] Stage-{stage_id} completed request {req_id}; forwarding or finalizing", + output_to_yield = self._process_single_result( + result=msg, + stage_id=stage_id, + metrics=req_state.metrics, + req_start_ts=req_start_ts, + wall_start_ts=wall_start_ts, + final_stage_id_for_e2e=req_final_stage_ids[req_id], ) - stage.set_engine_outputs(engine_outputs) - - if getattr(stage, "final_output", False): - logger.debug( - f"[{self._name}] Request {req_id} finalized at stage-{stage_id}", - ) - - # End-to-end timing and time-per-token for final output - # (only once per request at the designated final stage) - try: - if stage_id == final_stage_id_to_prompt[req_id]: - metrics.on_finalize_request( - stage_id, - req_id, - _req_start_ts.get(req_id, _wall_start_ts), - ) - except Exception as e: - logger.exception( - f"[{self._name}] Finalize request handling error for req {req_id} at stage {stage_id}: {e}", - ) - output_to_yield = OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, # type: ignore[attr-defined] - 
request_output=engine_outputs, - ) - - # Record audio generated frames (only when finished) - try: - finished = ( - engine_outputs.finished - if hasattr(engine_outputs, "finished") - else ( - engine_outputs[0].finished - if isinstance(engine_outputs, list) - and engine_outputs - and hasattr(engine_outputs[0], "finished") - else False - ) - ) - if finished: - metrics.record_audio_generated_frames(output_to_yield, stage_id, req_id) - except Exception as e: - logger.exception( - f"[{self._name}] Failed to record audio metrics for req {req_id} at stage {stage_id}: {e}", - ) - + if output_to_yield is not None: yield output_to_yield - next_stage_id = stage_id + 1 - if next_stage_id <= final_stage_id_to_prompt[req_id]: - # CFG: if this parent has companions, defer forwarding - if cfg.has_companions(req_id) and stage_id == 0: - if cfg.is_parent_failed(req_id): - cfg.consume_parent_failure(req_id) - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {req_id} skipped CFG forwarding due to " - f"companion failure ({completed_requests}/{total_requests})", - ) - continue - - if cfg.all_companions_done(req_id): - success = cfg.forward_parent_with_cfg( - req_id, - {"engine_outputs": engine_outputs, "stage_id": stage_id}, - self.stage_list, - self.connectors, - sampling_params_list, - request_id_to_prompt, - final_stage_id_to_prompt, - metrics, - remaining_by_stage, - ) - if not success: - cfg.consume_parent_failure(req_id) - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {req_id} dropped due to CFG forwarding failure " - f"({completed_requests}/{total_requests})", - ) - else: - cfg.defer_parent(req_id, engine_outputs, stage_id) - continue - - next_stage: OmniStage = self.stage_list[next_stage_id] - try: - # Derive inputs for the next stage, record preprocess time - with metrics.stage_postprocess_timer(stage_id, req_id): - next_inputs = next_stage.process_engine_inputs( - self.stage_list, [request_id_to_prompt[req_id]] - ) - except Exception as e: 
- completed_requests += 1 - logger.exception( - f"[{self._name}] Process engine inputs error for req {req_id}" - f" at stage {next_stage_id}: {e} ({completed_requests}/{total_requests})", - ) - continue - sp_next = sampling_params_list[next_stage_id] # type: ignore[index] - - # Check if we have a connector for this edge - connector_key = (str(stage_id), str(next_stage_id)) - connector = self.connectors.get(connector_key) - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=req_id, - next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=request_id_to_prompt[req_id], - next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - raise RuntimeError( - f"[{self._name}] Failed to send request {req_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." 
- ) - - logger.debug( - f"[{self._name}] Forwarded request {req_id} to stage-{next_stage_id}", - ) - remaining_by_stage[next_stage_id] += 1 - else: - completed_requests += 1 - if pbar: - final_mod = self.output_modalities[final_stage_id_to_prompt[req_id]] - pbar.unit = "img" if final_mod == "image" else "req" + if msg.get("finished"): + active_reqs.discard(req_id) + if pbar is not None: pbar.update(1) - logger.debug( - f"[{self._name}] Request {req_id} fully completed ({completed_requests}/{total_requests})", - ) - for timed_out_id in cfg.check_timeouts(): - completed_requests += 1 - logger.error( - f"[{self._name}] Parent {timed_out_id} timed out; counting as failed " - f"({completed_requests}/{total_requests})", - ) - - if not made_progress: - time.sleep(0.005) - logger.debug(f"[{self._name}] All requests completed") - - if pbar: - pbar.close() - - # Summarize and print stats - try: - metrics.build_and_log_summary() - except Exception as e: - logger.exception(f"[{self._name}] Failed to build/log summary: {e}") - - @property - def _name(self) -> str: - return "Orchestrator" + self._log_summary_and_cleanup(req_id) + except Exception: + if "active_reqs" in locals() and active_reqs: + self.abort(list(active_reqs)) + raise + finally: + if "pbar" in locals() and pbar is not None: + pbar.close() + + def abort(self, request_id: str | Iterable[str]) -> None: + request_ids = [request_id] if isinstance(request_id, str) else list(request_id) + self.engine.abort(request_ids) + for req_id in request_ids: + self.request_states.pop(req_id, None) + if self.log_stats: + logger.info("[Omni] Aborted request(s) %s", ",".join(request_ids)) diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py new file mode 100644 index 0000000000..fb5f3788db --- /dev/null +++ b/vllm_omni/entrypoints/omni_base.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import os +import time +import types +import weakref +from collections.abc import Sequence +from pprint 
import pformat +from typing import TYPE_CHECKING, Any, Literal + +import huggingface_hub +from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError + +from vllm_omni.engine.async_omni_engine import AsyncOmniEngine +from vllm_omni.entrypoints.client_request_state import ClientRequestState +from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e +from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific +from vllm_omni.outputs import OmniRequestOutput + +if TYPE_CHECKING: + from vllm_omni.engine.arg_utils import OmniEngineArgs + +logger = init_logger(__name__) + + +def _weak_shutdown_engine(engine: AsyncOmniEngine) -> None: + """Best-effort engine cleanup for GC finalization.""" + try: + engine.shutdown() + except Exception: + pass + + +def omni_snapshot_download(model_id: str) -> str: + if os.path.exists(model_id): + return model_id + + # TODO: this is just a workaround for quickly use modelscope, we should support + # modelscope in weight loading feature instead of using `snapshot_download` + if os.environ.get("VLLM_USE_MODELSCOPE", False): + from modelscope.hub.snapshot_download import snapshot_download + + return snapshot_download(model_id) + + try: + download_weights_from_hf_specific( + model_name_or_path=model_id, + cache_dir=None, + allow_patterns=["*"], + require_all=True, + ) + except huggingface_hub.errors.RepositoryNotFoundError: + logger.warning("Repository not found for '%s'.", model_id) + + return model_id + + +OutputMessageHandleResult = tuple[Literal[True], None, None, None] | tuple[Literal[False], str, int, ClientRequestState] + + +class OmniBase: + """Shared runtime foundation for AsyncOmni and Omni.""" + + def __init__( + self, + model: str, + **kwargs: Any, + ) -> None: + engine_args: OmniEngineArgs | None = kwargs.pop("engine_args", None) + stage_init_timeout = 
kwargs.pop("stage_init_timeout", 300) + init_timeout = kwargs.pop("init_timeout", 600) + log_stats = kwargs.pop("log_stats", False) + async_chunk = kwargs.pop("async_chunk", False) + output_modalities = kwargs.pop("output_modalities", None) + + if "log_requests" in kwargs: + raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.") + model = omni_snapshot_download(model) + self.model = model + self.log_stats = log_stats + self.async_chunk = async_chunk + self.output_modalities = output_modalities or [] + + logger.info("[%s] Initializing with model %s", self.__class__.__name__, model) + st = time.time() + self.engine = AsyncOmniEngine( + model=model, + engine_args=engine_args, + init_timeout=init_timeout, + stage_init_timeout=stage_init_timeout, + **kwargs, + ) + self._shutdown_called = False + self._weak_finalizer = weakref.finalize(self, _weak_shutdown_engine, self.engine) + et = time.time() + logger.info("[%s] AsyncOmniEngine initialized in %.2f seconds", self.__class__.__name__, et - st) + self.async_chunk = bool(self.async_chunk or getattr(self.engine, "async_chunk", False)) + + self.request_states: dict[str, ClientRequestState] = {} + + self.default_sampling_params_list = self.engine.default_sampling_params_list + if not self.output_modalities: + self.output_modalities = [ + self.engine.get_stage_metadata(i).get("final_output_type") for i in range(self.engine.num_stages) + ] + + self._stage_meta_list = [ + types.SimpleNamespace(**self.engine.get_stage_metadata(i)) for i in range(self.engine.num_stages) + ] + + logger.info( + "[%s] Initialized with %s stages for model %s", + self.__class__.__name__, + self.engine.num_stages, + model, + ) + + @property + def num_stages(self) -> int: + return self.engine.num_stages + + @property + def is_running(self) -> bool: + return self.engine.is_alive() + + def check_health(self) -> None: + if not self.engine.is_alive(): + raise EngineDeadError("Orchestrator process is not alive") + + def 
resolve_sampling_params_list( + self, + sampling_params_list: Sequence[Any] | Any | None, + ) -> Sequence[Any]: + if sampling_params_list is None: + normalized = self.default_sampling_params_list + elif isinstance(sampling_params_list, Sequence) and not isinstance(sampling_params_list, (str, bytes)): + normalized = sampling_params_list + elif self.num_stages == 1: + normalized = [sampling_params_list] + else: + raise ValueError(f"Expected {self.num_stages} sampling params, got a single sampling params object") + if len(normalized) != self.num_stages: + raise ValueError(f"Expected {self.num_stages} sampling params, got {len(normalized)}") + return normalized + + def _log_summary_and_cleanup(self, request_id: str) -> None: + req_state = self.request_states.get(request_id) + try: + if req_state is None or req_state.metrics is None: + return + summary = req_state.metrics.build_and_log_summary() + logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + except Exception: + logger.exception( + "[%s] Failed to build/log summary for req=%s", + self.__class__.__name__, + request_id, + ) + finally: + self.request_states.pop(request_id, None) + + def _compute_final_stage_id(self, output_modalities: list[str] | None) -> int: + return get_final_stage_id_for_e2e( + output_modalities, + self.output_modalities, + self._stage_meta_list, + ) + + def _process_stage_metrics_message(self, msg: dict[str, Any]) -> None: + req_id = msg.get("request_id") + req_state = self.request_states.get(req_id) + if req_state is None or req_state.metrics is None: + return + _m = msg.get("metrics") + if _m is None: + return + stage_id = msg.get("stage_id", 0) + req_state.metrics.on_stage_metrics(stage_id, req_id, _m) + submit_ts = msg.get("stage_submit_ts") + now = time.time() + if req_state.metrics.stage_first_ts[stage_id] is None: + req_state.metrics.stage_first_ts[stage_id] = submit_ts if submit_ts is not None else now + req_state.metrics.stage_last_ts[stage_id] = 
max(req_state.metrics.stage_last_ts[stage_id] or 0.0, now) + + def _handle_output_message( + self, + msg: dict[str, Any] | None, + ) -> OutputMessageHandleResult: + """Handle one Orchestrator output-queue message.""" + if msg is None: + return True, None, None, None + + msg_type = msg.get("type") + if msg_type == "stage_metrics": + self._process_stage_metrics_message(msg) + return True, None, None, None + + if msg_type == "error": + raise RuntimeError(msg.get("error", "Orchestrator returned an error message")) + + if msg_type != "output": + logger.warning("[%s] got unexpected msg type: %s", self.__class__.__name__, msg_type) + return True, None, None, None + + req_id = msg.get("request_id") + if req_id is None: + logger.warning("[%s] got output message without request_id", self.__class__.__name__) + return True, None, None, None + + stage_id = msg.get("stage_id") + if stage_id is None: + logger.warning("[%s] got output message without stage_id for req=%s", self.__class__.__name__, req_id) + return True, None, None, None + + req_state = self.request_states.get(req_id) + if req_state is None: + logger.debug( + "[%s] dropping output for unknown req %s", + self.__class__.__name__, + req_id, + ) + return True, None, None, None + + req_state.stage_id = stage_id + + return False, req_id, stage_id, req_state + + def _process_single_result( + self, + result: dict[str, Any], + stage_id: int, + metrics: OrchestratorMetrics, + req_start_ts: dict[str, float], + wall_start_ts: float, + final_stage_id_for_e2e: int, + ) -> OmniRequestOutput | None: + req_id = result.get("request_id") + engine_outputs = result.get("engine_outputs") + finished = engine_outputs.finished + + submit_ts = result.get("stage_submit_ts") + now = time.time() + if metrics.stage_first_ts[stage_id] is None: + metrics.stage_first_ts[stage_id] = submit_ts if submit_ts is not None else now + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, now) + + _m = result.get("metrics") + if 
finished and _m is not None: + metrics.on_stage_metrics(stage_id, req_id, _m) + + stage_meta = self.engine.get_stage_metadata(stage_id) + if not stage_meta["final_output"]: + return None + + try: + rid_key = str(req_id) + if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done and finished: + metrics.on_finalize_request( + stage_id, + req_id, + req_start_ts.get(req_id, wall_start_ts), + ) + except Exception: + logger.exception("[%s] Finalize request handling error", self.__class__.__name__) + + images = getattr(engine_outputs, "images", []) if stage_meta["final_output_type"] == "image" else [] + return OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage_meta["final_output_type"], + request_output=engine_outputs, + images=images, + ) + + def shutdown(self) -> None: + logger.info("[%s] Shutting down", self.__class__.__name__) + self._shutdown_base() + + def close(self) -> None: + self.shutdown() + + def _shutdown_base(self) -> None: + if getattr(self, "_shutdown_called", False): + return + self._shutdown_called = True + finalizer = getattr(self, "_weak_finalizer", None) + if finalizer is not None and finalizer.alive: + finalizer.detach() + self.engine.shutdown() diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py deleted file mode 100644 index d70aa5c520..0000000000 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ /dev/null @@ -1,169 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -import uuid -from collections.abc import Sequence - -from vllm.logger import init_logger -from vllm.transformers_utils.config import get_hf_file_to_dict - -from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig -from vllm_omni.diffusion.diffusion_engine import DiffusionEngine -from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, 
OmniPromptType -from vllm_omni.outputs import OmniRequestOutput - -logger = init_logger(__name__) - - -class OmniDiffusion: - """ - It is the main class to interact with vLLM-Omni diffusion models. - It acts as a high-level interface that prepares requests and - delegates the actual diffusion process to the DiffusionEngine. - - You can pass either an `OmniDiffusionConfig` via `od_config`, or - pass kwargs such as `model="Qwen/Qwen-Image"`, - which will be forwarded to `OmniDiffusionConfig.from_kwargs`. - """ - - def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): - # Capture stage info from kwargs before they might be filtered out - stage_id = kwargs.get("stage_id") - engine_input_source = kwargs.get("engine_input_source") - cfg_kv_collect_func = kwargs.pop("cfg_kv_collect_func", None) - - if od_config is None: - od_config = OmniDiffusionConfig.from_kwargs(**kwargs) - elif isinstance(od_config, dict): - # If config is dict, check it too (priority to kwargs if both exist) - if stage_id is None: - stage_id = od_config.get("stage_id") - if engine_input_source is None: - engine_input_source = od_config.get("engine_input_source") - od_config = OmniDiffusionConfig.from_kwargs(**od_config) - - self.od_config = od_config - - # Inject stage info into omni_kv_config if present - if stage_id is not None: - self.od_config.omni_kv_config.setdefault("stage_id", stage_id) - if engine_input_source is not None: - self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) - - # Detect model class and load config - # Diffusers-style models expose `model_index.json` with `_class_name`. - # Non-diffusers models (e.g. Bagel, NextStep, GLM-Image) only have `config.json`, - # so we fall back to reading that and mapping model_type manually. 
- try: - config_dict = get_hf_file_to_dict( - "model_index.json", - od_config.model, - ) - - if config_dict is None: - raise FileNotFoundError("model_index.json not found") - - if od_config.model_class_name is None: - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() - - if od_config.model_class_name == "DreamIDOmniPipeline": - od_config.model_config = config_dict - else: - tf_config_dict = get_hf_file_to_dict( - "transformer/config.json", - od_config.model, - ) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - - except (AttributeError, OSError, ValueError, FileNotFoundError): - cfg = get_hf_file_to_dict("config.json", od_config.model) - if cfg is None: - raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") - - # Map model_type or architecture to pipeline class - model_type = cfg.get("model_type") - architectures = cfg.get("architectures") or [] - pipeline_class = None - # Bagel/NextStep models don't have a model_index.json, so we set the pipeline class name manually - if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: - pipeline_class = "BagelPipeline" - elif model_type == "nextstep": - if od_config.model_class_name is None: - pipeline_class = "NextStep11Pipeline" - elif model_type == "glm-image" or "GlmImageForConditionalGeneration" in architectures: - pipeline_class = "GlmImagePipeline" - elif architectures and len(architectures) == 1: - pipeline_class = architectures[0] - - if pipeline_class is None: - raise ValueError(f"Unknown model type: {model_type}, architectures: {architectures}") - - if od_config.model_class_name is None: - od_config.model_class_name = pipeline_class - od_config.tf_model_config = TransformerConfig().from_dict(cfg) - od_config.update_multimodal_support() - - if cfg_kv_collect_func is not None: - od_config.cfg_kv_collect_func = cfg_kv_collect_func - - self.engine: DiffusionEngine = 
DiffusionEngine.make_engine(od_config) - - def generate( - self, - prompts: OmniPromptType | Sequence[OmniPromptType], - sampling_params: OmniDiffusionSamplingParams, - request_ids: list[str] = [], - ) -> list[OmniRequestOutput]: - _t0 = time.perf_counter() - if isinstance(prompts, (str, dict)): - prompts = [prompts] - else: - prompts = list(prompts) - - # Check if request_id is provided in kwargs - if len(request_ids) < len(prompts): - request_ids.extend(f"{i + len(request_ids)}_{uuid.uuid4()}" for i in range(len(prompts) - len(request_ids))) - - request = OmniDiffusionRequest(prompts, sampling_params, request_ids) - result = self._run_engine(request) - _t_ms = (time.perf_counter() - _t0) * 1000 - logger.info("OmniDiffusion.generate total: %.2f ms", _t_ms) - return result - - def _run_engine(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: - return self.engine.step(request) - - def close(self) -> None: - self.engine.close() - - def __del__(self): # pragma: no cover - best effort cleanup - try: - self.close() - except Exception: - pass - - def start_profile(self, trace_filename: str | None = None) -> None: - """Start profiling for the diffusion model. - - Args: - trace_filename: Optional base filename for trace files. - If None, a timestamp-based name will be generated. - """ - if hasattr(self, "engine") and self.engine: - self.engine.start_profile(trace_filename) - else: - raise RuntimeError("Diffusion engine not initialized") - - def stop_profile(self) -> dict: - """Stop profiling and return profiling results. - - Returns: - Dictionary containing paths to trace and table files. 
- """ - if hasattr(self, "engine") and self.engine: - return self.engine.stop_profile() - else: - raise RuntimeError("Diffusion engine not initialized") diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py deleted file mode 100644 index e06e50c83a..0000000000 --- a/vllm_omni/entrypoints/omni_llm.py +++ /dev/null @@ -1,275 +0,0 @@ -from collections.abc import Callable, Sequence -from typing import Any - -import cloudpickle -from pydantic import ValidationError -from tqdm import tqdm - -# External library imports (vLLM) -from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field -from vllm.entrypoints.llm import LLM -from vllm.inputs import ProcessorInputs, PromptType -from vllm.logger import init_logger -from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.plugins.io_processors import get_io_processor -from vllm.renderers.inputs.preprocess import parse_model_prompt -from vllm.usage.usage_lib import UsageContext -from vllm.utils.counter import Counter -from vllm.v1.engine.llm_engine import LLMEngine - -from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors - -# Internal imports (our code) -from vllm_omni.engine.arg_utils import OmniEngineArgs -from vllm_omni.engine.input_processor import OmniInputProcessor, reinject_omni_fields -from vllm_omni.engine.output_processor import MultimodalOutputProcessor -from vllm_omni.entrypoints.utils import ( - filter_dataclass_kwargs, - resolve_model_config_path, -) - -logger = init_logger(__name__) - - -class OmniLLM(LLM): - """Main entry point for vLLM-Omni inference. - - This class extends the base vLLM LLM class with omni-specific - processors for handling multimodal inputs and outputs. It provides - configuration loading for multi-stage pipelines, while stage management - is handled by the Omni class. 
- - Args: - model: Model name or path to load - log_stats: Whether to enable statistics logging - compilation_config: Optional compilation configuration. Can be an - integer (compilation level), dict, or CompilationConfig instance. - hf_overrides: Optional HuggingFace model configuration overrides - structured_outputs_config: Optional structured outputs configuration. - Can be a dict or StructuredOutputsConfig instance. - init_sleep_seconds: Number of seconds to sleep between starting - each stage process during initialization (used by Omni class) - shm_threshold_bytes: Threshold in bytes for using shared memory - for IPC. Objects larger than this threshold will use shared memory. - batch_timeout: Timeout in seconds for batching requests within a stage - init_timeout: Timeout in seconds for waiting for all stages to initialize - **kwargs: Additional keyword arguments passed to the base LLM class - and engine - - Example: - >>> llm = OmniLLM(model="Qwen/Qwen2.5-Omni-7B") - >>> # Stage management is handled by Omni class - """ - - def __init__( - self, - model: str, - log_stats: bool = False, - compilation_config: int | dict[str, Any] | CompilationConfig | None = None, - hf_overrides: dict[str, Any] | None = None, - structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, - init_sleep_seconds: int = 20, - shm_threshold_bytes: int = 65536, - batch_timeout: int = 10, - init_timeout: int = 300, - **kwargs: Any, - ): - """LLM constructor with omni-specific configuration loading.""" - # Store stage management parameters (used by Omni class) - self.worker_backend = kwargs.get("worker_backend", "multi_process") - self.ray_address = kwargs.get("ray_address", None) - self.batch_timeout = batch_timeout - self.log_stats: bool = bool(log_stats) - - # Resolve model config path for connectors - self.config_path = resolve_model_config_path(model) - - # Initialize connectors - self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( 
- self.config_path, worker_backend=self.worker_backend, shm_threshold_bytes=shm_threshold_bytes - ) - - # Initialize LLM engine - if "disable_log_stats" not in kwargs: - kwargs["disable_log_stats"] = True - - if "worker_cls" in kwargs: - worker_cls = kwargs["worker_cls"] - # if the worker_cls is not qualified string name, - # we serialize it using cloudpickle to avoid pickling issues - if isinstance(worker_cls, type): - kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) - - if "kv_transfer_config" in kwargs and isinstance(kwargs["kv_transfer_config"], dict): - from vllm.config.kv_transfer import KVTransferConfig - - raw_config_dict = kwargs["kv_transfer_config"] - try: - kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict) - except ValidationError as e: - logger.error( - "Failed to convert 'kv_transfer_config' dict to KVTransferConfig object. Dict: %s. Error: %s", - raw_config_dict, - e, - ) - raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e - - # Extract omni_kv_config from kwargs if present (injected by Omni) - omni_kv_config = kwargs.pop("omni_kv_config", None) - - if compilation_config is not None: - if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig(level=compilation_config) - elif isinstance(compilation_config, dict): - compilation_config_instance = CompilationConfig( - **{k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k)} - ) - else: - compilation_config_instance = compilation_config - else: - compilation_config_instance = CompilationConfig() - - if structured_outputs_config is not None: - if isinstance(structured_outputs_config, dict): - structured_outputs_instance = StructuredOutputsConfig( - **{k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k)} - ) - else: - structured_outputs_instance = structured_outputs_config - else: - structured_outputs_instance = StructuredOutputsConfig() - - engine_args = 
OmniEngineArgs( - model=model, - compilation_config=compilation_config_instance, - structured_outputs_config=structured_outputs_instance, - omni_kv_config=omni_kv_config, - hf_overrides=hf_overrides or {}, - **filter_dataclass_kwargs(OmniEngineArgs, kwargs), - ) - - # Create the Engine (autoselects V0 vs V1) - self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) - self.llm_engine.output_processor = MultimodalOutputProcessor( - tokenizer=self.llm_engine.tokenizer, - log_stats=self.llm_engine.log_stats, - engine_core_output_type=engine_args.engine_output_type, - ) - self.llm_engine.input_processor = OmniInputProcessor(vllm_config=self.llm_engine.vllm_config) - self.engine_class = type(self.llm_engine) - - self.request_counter = Counter() - self.default_sampling_params: dict[str, Any] | None = None - - supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore - - logger.info("Supported_tasks: %s", supported_tasks) - - self.supported_tasks = supported_tasks - - # Keep parity with vLLM's LLM initialization fields so inherited - # generate/chat preprocessing paths work as expected. 
- self.renderer = self.llm_engine.renderer - # Load the Input/Output processor plugin if any - io_processor_plugin = self.llm_engine.model_config.io_processor_plugin - self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin) - self.model_config = self.llm_engine.model_config - self.input_processor = self.llm_engine.input_processor - - # Parity with upstream LLM for pooling/classify entrypoints - chat_template = kwargs.get("chat_template", None) - from vllm.entrypoints.chat_utils import ChatTemplateConfig, load_chat_template - from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors - - self.chat_template = load_chat_template(chat_template) - self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template) - self.init_pooling_io_processors = init_pooling_io_processors( - supported_tasks=supported_tasks, - model_config=self.model_config, - renderer=self.renderer, - chat_template_config=self.chat_template_config, - ) - - # ------------------------------------------------------------------ - # Override upstream _preprocess_cmpl so that omni-specific fields - # (additional_information, prompt_embeds, …) survive the renderer's - # process_for_engine step which only copies standard vLLM keys. - # ------------------------------------------------------------------ - - def _preprocess_cmpl( - self, - prompts: Sequence[PromptType], - tokenization_kwargs: dict[str, Any] | None = None, - ) -> Sequence[ProcessorInputs]: - renderer = self.renderer - model_config = self.model_config - - parsed_prompts = [parse_model_prompt(model_config, prompt) for prompt in prompts] - tok_params = renderer.default_cmpl_tok_params.with_kwargs(**(tokenization_kwargs or {})) - results = renderer.render_cmpl(parsed_prompts, tok_params) - - reinject_omni_fields(results, parsed_prompts) - return results - - def close(self) -> None: - """Close resources. - - Note: Stage management is now handled by Omni class. 
- This method closes the LLM engine but not stages. - """ - # Close the LLM engine if it exists - if hasattr(self, "llm_engine") and self.llm_engine is not None: - if hasattr(self.llm_engine, "shutdown"): - self.llm_engine.shutdown() - - def __del__(self) -> None: # best-effort - try: - self.close() - except Exception as e: - logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) - - def _run_engine( - self, output_type=None, *, use_tqdm: bool | Callable[..., tqdm] = True - ) -> list[RequestOutput | PoolingRequestOutput]: - # Initialize tqdm. - if use_tqdm: - num_requests = self.llm_engine.get_num_unfinished_requests() - tqdm_func = use_tqdm if callable(use_tqdm) else tqdm - pbar = tqdm_func( - total=num_requests, - desc="Processed prompts", - dynamic_ncols=True, - postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), - ) - - # Run the engine. - outputs: list[RequestOutput | PoolingRequestOutput] = [] - total_in_toks = 0 - total_out_toks = 0 - while self.llm_engine.has_unfinished_requests(): - step_outputs = self.llm_engine.step() - for output in step_outputs: - if output.finished: - outputs.append(output) - if use_tqdm: - if isinstance(output, RequestOutput): - # Calculate tokens only for RequestOutput - n = len(output.outputs) - assert output.prompt_token_ids is not None - total_in_toks += len(output.prompt_token_ids) * n - in_spd = total_in_toks / pbar.format_dict["elapsed"] - total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) - out_spd = total_out_toks / pbar.format_dict["elapsed"] - pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" - pbar.update(n) - else: - pbar.update(1) - if pbar.n == num_requests: - pbar.refresh() - - if use_tqdm: - pbar.close() - # Sort the outputs by the int part of request ID which is in format of 'int-uuid'. - # This is necessary because some requests may be finished earlier than - # its previous requests. 
- return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py deleted file mode 100644 index 1298322676..0000000000 --- a/vllm_omni/entrypoints/omni_stage.py +++ /dev/null @@ -1,1753 +0,0 @@ -""" -Stage manager for orchestrating multiple engines in vLLM-Omni. - -Enhanced to encapsulate per-stage process lifecycle and worker logic -(device setup, LLM init, batching, shared-memory IPC), while preserving -the original input processing utilities for cross-stage data wiring. -""" - -import asyncio -import fcntl -import importlib -import multiprocessing as mp -import os -import queue -import sys -import time -import traceback -import uuid -from collections.abc import Callable, Sequence -from contextlib import contextmanager -from dataclasses import fields -from typing import Any, Literal, TypeVar, cast - -from vllm import PromptType, RequestOutput -from vllm.inputs import TextPrompt -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams -from vllm.tokenizers import TokenizerLike -from vllm.usage.usage_lib import UsageContext -from vllm.v1.engine import EngineCoreOutput -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.engine.llm_engine import LLMEngine - -from vllm_omni.diffusion.data import OmniDiffusionConfig -from vllm_omni.distributed.omni_connectors import build_stage_connectors -from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector -from vllm_omni.distributed.omni_connectors.connectors.base import OmniConnectorBase -from vllm_omni.distributed.ray_utils.utils import ( - get_ray_task_error, - is_ray_task_alive, - kill_ray_actor, - start_ray_actor, -) -from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs, OmniEngineArgs -from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion -from vllm_omni.entrypoints.async_omni_llm import 
AsyncOmniLLM -from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion -from vllm_omni.entrypoints.omni_llm import OmniLLM -from vllm_omni.entrypoints.stage_utils import ( - SHUTDOWN_TASK, - OmniStageTaskType, - _resolve_model_tokenizer_paths, - _to_dict, - is_profiler_task, - load_func_from_config, - maybe_dump_to_shm, - set_stage_devices, -) -from vllm_omni.entrypoints.utils import detect_pid_host, filter_dataclass_kwargs -from vllm_omni.entrypoints.zmq_utils import ( - ZmqQueue, - create_zmq_queue, -) -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams, OmniTokensPrompt -from vllm_omni.metrics import count_tokens_from_outputs -from vllm_omni.outputs import OmniRequestOutput - -_R = TypeVar("_R") - -logger = init_logger(__name__) - - -@contextmanager -def _sequential_init_lock(engine_args: dict[str, Any], stage_init_timeout: int = 300): - """Acquire device locks for sequential init if NVML is unavailable. - - If process-scoped memory tracking is available (NVML works), stages can - safely initialize concurrently — each measures only its own GPU memory. - Otherwise, fall back to file-based locks to serialize initialization. 
- """ - from vllm_omni.worker.gpu_memory_utils import is_process_scoped_memory_available - - if is_process_scoped_memory_available() and detect_pid_host(): - logger.debug( - "NVML process-scoped memory available and PID host is available — concurrent init is safe, skipping locks" - ) - yield - return - else: - logger.debug( - "NVML unavailable or PID host is not available (usually inside a container, " - "--pid=host is not set in docker run command) — using sequential init locks" - ) - - from vllm_omni.platforms import current_omni_platform - - # Get all parallel sizes from engine_args or parallel_config (defaults to 1) - if "parallel_config" in engine_args: - parallel_config = engine_args["parallel_config"] - tensor_parallel_size = parallel_config.get("tensor_parallel_size", 1) - pipeline_parallel_size = parallel_config.get("pipeline_parallel_size", 1) - data_parallel_size = parallel_config.get("data_parallel_size", 1) - prefill_context_parallel_size = parallel_config.get("prefill_context_parallel_size", 1) - sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) - cfg_parallel_size = parallel_config.get("cfg_parallel_size", 1) - else: - tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) - pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) - data_parallel_size = engine_args.get("data_parallel_size", 1) - prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) - sequence_parallel_size = 1 - cfg_parallel_size = 1 - - num_devices_per_stage = ( - tensor_parallel_size - * pipeline_parallel_size - * data_parallel_size - * prefill_context_parallel_size - * sequence_parallel_size - * cfg_parallel_size - ) - - # Get physical device IDs from device control env var - device_control_env = current_omni_platform.device_control_env_var - visible_devices_str = os.environ.get(device_control_env) - physical_devices = [] - - if visible_devices_str: - try: - physical_devices = [int(x.strip()) for x in 
visible_devices_str.split(",") if x.strip()] - except (ValueError, IndexError): - pass - - if not physical_devices: - num_devices = current_omni_platform.get_device_count() - physical_devices = list(range(num_devices)) - - num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) - devices_to_lock = sorted(physical_devices[:num_devices_to_lock]) - - logger.debug( - "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d, CFG=%d; will lock %d devices: %s", - tensor_parallel_size, - pipeline_parallel_size, - data_parallel_size, - prefill_context_parallel_size, - sequence_parallel_size, - cfg_parallel_size, - num_devices_to_lock, - devices_to_lock, - ) - - # Acquire exclusive locks for all devices using fcntl.flock - wait_start = time.time() - acquired_lock_fds = [] - - for device_id in devices_to_lock: - lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" - lock_acquired = False - - while not lock_acquired: - try: - lock_fd = os.open(lock_file, os.O_CREAT | os.O_RDWR, 0o644) - - try: - fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - os.ftruncate(lock_fd, 0) - os.write(lock_fd, f"{os.getpid()}\n".encode()) - os.fsync(lock_fd) - lock_acquired = True - acquired_lock_fds.append(lock_fd) - logger.debug("Acquired exclusive lock for device %s", device_id) - except BlockingIOError: - os.close(lock_fd) - - if time.time() - wait_start > stage_init_timeout: - logger.warning( - "Timeout waiting for device %s initialization lock, proceeding anyway", - device_id, - ) - break - - time.sleep(0.1) - except OSError as e: - logger.debug( - "Failed to acquire lock for device %s: %s, continuing anyway", - device_id, - e, - ) - try: - os.close(lock_fd) - except (OSError, NameError): - pass - break - - # Set FD_CLOEXEC to prevent child processes from inheriting locks - for lock_fd in acquired_lock_fds: - try: - flags = fcntl.fcntl(lock_fd, fcntl.F_GETFD) - fcntl.fcntl(lock_fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) - except (OSError, ValueError): - pass - - try: - 
yield - finally: - for lock_fd in acquired_lock_fds: - try: - fcntl.flock(lock_fd, fcntl.LOCK_UN) - os.close(lock_fd) - logger.debug("Released initialization lock (fd=%s)", lock_fd) - except (OSError, ValueError): - pass - - -def _resolve_worker_cls(engine_args: dict[str, Any]) -> None: - worker_type = engine_args.get("worker_type", None) - if not worker_type: - return - worker_cls = engine_args.get("worker_cls") - if worker_cls is not None and worker_cls != "auto": - return - from vllm_omni.platforms import current_omni_platform - - worker_type = str(worker_type).lower() - if worker_type == "ar": - engine_args["worker_cls"] = current_omni_platform.get_omni_ar_worker_cls() - elif worker_type == "generation": - engine_args["worker_cls"] = current_omni_platform.get_omni_generation_worker_cls() - else: - raise ValueError(f"Unknown worker_type: {worker_type}") - - -def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: - """Build OmniDiffusionConfig kwargs from engine args.""" - od_config = engine_args.get("od_config", {}) - if not od_config: - od_config = {"model": model} - od_field_names = {f.name for f in fields(OmniDiffusionConfig)} - for key, value in engine_args.items(): - if key in od_field_names: - od_config[key] = value - return od_config - - -class OmniStage: - """Stage manager for orchestrating a single stage in the omni pipeline. - - Encapsulates per-stage process lifecycle and worker logic, including - device setup, LLM initialization, batching, and shared-memory IPC. - Preserves input processing utilities for cross-stage data wiring. 
- - Args: - stage_config: Stage configuration object containing engine arguments, - runtime settings, and stage-specific parameters - """ - - def __init__(self, stage_config: Any, stage_init_timeout: int = 300): - logger.debug(f"[OmniStage] stage_config: {stage_config}") - self.stage_config = stage_config - self.engine = None - self.async_engine = None - self.vllm_config = None - self.tokenizer = None - self.input_preprocessor = None - self.is_tracing_enabled = False - self.stage_id = stage_config.stage_id - self.engine_args = stage_config.engine_args - self.model_stage = stage_config.engine_args.model_stage - self.requires_multimodal_data = getattr(stage_config.runtime, "requires_multimodal_data", False) - # Support both 'input_sources' (new format) and 'engine_input_source' (legacy) - self.engine_input_source = ( - getattr(stage_config, "input_sources", None) or getattr(stage_config, "engine_input_source", []) or [] - ) - self.engine_output_type = getattr(stage_config.engine_args, "engine_output_type", None) - self.engine_outputs = None - self.is_comprehension = getattr(stage_config, "is_comprehension", False) - # Support for different stage types: "llm" (default) or "diffusion" - self.stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm") - if ( - "stage_id" in stage_config.engine_args - and stage_config.engine_args.stage_id != self.stage_id - and self.stage_id is not None - ): - stage_config.engine_args.stage_id = self.stage_id - if hasattr(stage_config, "custom_process_input_func"): - # Import the module specified in the config (already a full module path) - module_path, func_name = stage_config.custom_process_input_func.rsplit(".", 1) - module = importlib.import_module(module_path) - self.custom_process_input_func = getattr(module, func_name) - else: - self.custom_process_input_func = None - - self.prompt_expand_func = load_func_from_config(getattr(stage_config, "prompt_expand_func", None)) - self.final_output = 
getattr(stage_config, "final_output", False) - self.final_output_type = getattr(stage_config, "final_output_type", None) - self.tts_args = _to_dict(getattr(stage_config, "tts_args", {})) - default_sampling_params = getattr(stage_config, "default_sampling_params", {}) - # For LLM stage, this can directly be a SamplingParams-compatible dict; - # For diffusion stage, this only serves as default values for diffusion kwargs. - default_sampling_params = _to_dict(default_sampling_params) - # Further convert it to dataclass to check fields - try: - self.default_sampling_params = ( - SamplingParams if self.stage_type == "llm" else OmniDiffusionSamplingParams - )(**default_sampling_params) - except TypeError as error: - raise TypeError(f"Invalid default_sampling_params for stage {self.stage_id}: {error}") from error - # Runtime orchestration state (added) - self._in_q: mp.queues.Queue | ZmqQueue | str | None = None - self._out_q: mp.queues.Queue | ZmqQueue | str | None = None - self._proc: mp.Process | None = None - self._ray_actor: Any | None = None - self._ray_task_ref: Any | None = None - self._shm_threshold_bytes: int = 65536 - self._stage_init_timeout: int = stage_init_timeout - # Callback used by the orchestrator's output handler to stash - # collective_rpc results so that ``collective_rpc`` can retrieve - # them without competing for the output queue. - self._rpc_result_checker: Callable[[str], dict | None] | None = None - - def set_engine(self, engine: LLMEngine) -> None: - """Set the LLM engine for this stage. - - Args: - engine: LLMEngine instance to use for this stage - """ - self.engine = engine - - def set_async_engine(self, async_engine: AsyncLLM) -> None: - """Set the async LLM engine for this stage. - - Args: - async_engine: AsyncLLM instance to use for this stage - """ - self.async_engine = async_engine - - def set_vllm_config(self, vllm_config: Any) -> None: - """Set the vLLM configuration for this stage. 
- - Args: - vllm_config: VllmConfig instance received from worker process - """ - self.vllm_config = vllm_config - - def set_tokenizer(self, tokenizer: TokenizerLike) -> None: - """Set the tokenizer for this stage. - - Args: - tokenizer: Tokenizer instance received from worker process - """ - self.tokenizer = tokenizer - - def set_input_preprocessor(self, input_preprocessor: InputPreprocessor) -> None: - """Set the input preprocessor for this stage. - - Args: - input_preprocessor: InputPreprocessor instance received from worker process - """ - self.input_preprocessor = input_preprocessor - - def set_is_tracing_enabled(self, is_tracing_enabled: bool) -> None: - """Set whether tracing is enabled for this stage. - - Args: - is_tracing_enabled: Boolean indicating if tracing is enabled - """ - self.is_tracing_enabled = is_tracing_enabled - - def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: - """Set the engine outputs for this stage. - - Args: - engine_outputs: EngineCoreOutput from this stage's processing - """ - self.engine_outputs = engine_outputs - - # ----------------- New Orchestration APIs ----------------- - def attach_queues( - self, - in_q: mp.queues.Queue | ZmqQueue | str | None, - out_q: mp.queues.Queue | ZmqQueue | str | None, - ) -> None: - """Attach input and output queues for IPC communication. 
- - Args: - in_q: Input queue for receiving tasks from orchestrator (queue object or endpoint string) - out_q: Output queue for sending results to orchestrator (queue object or endpoint string) - """ - self._in_q = in_q - self._out_q = out_q - - def stop_profile(self) -> dict: - """Stop profiling by sending a signal to worker and waiting for response.""" - if self._in_q is None or self._out_q is None: - logger.warning(f"[Stage-{self.stage_id}] Queues not initialized, cannot stop profile.") - return {} - - logger.info(f"[Stage-{self.stage_id}] Sending PROFILER_STOP to worker...") - self.submit({"type": OmniStageTaskType.PROFILER_STOP}) - - # Wait for result from worker - try: - # Profiling stop might take time to flush files, give it 600s - response = self._out_q.get(timeout=600) - - if isinstance(response, dict): - if response.get("type") == "profiler_result": - return response.get("data", {}) - elif "error" in response: - logger.error(f"[Stage-{self.stage_id}] Profiler error: {response['error']}") - return {} - - # If we got something else (e.g. late generation result), we might lose it here, - # but usually profiling stop is called when generation is done. - logger.warning( - f"[Stage-{self.stage_id}] Received unexpected message while waiting for profiler: {response}" - ) - return {} - - except queue.Empty: - logger.error(f"[Stage-{self.stage_id}] Timeout waiting for profiler results.") - return {} - - def collective_rpc( - self, - method: str | Callable[..., _R], - timeout: float | None = None, - args: tuple = (), - kwargs: dict[str, Any] | None = None, - ) -> list[_R]: - """Execute an RPC call on all workers via the stage engine. - - Args: - method: Name of the worker method to execute, or a callable that - is serialized and sent to all workers to execute. - - If the method is a callable, it should accept an additional - ``self`` argument, in addition to the arguments passed in - ``args`` and ``kwargs``. The ``self`` argument will be the - worker object. 
- timeout: Maximum time in seconds to wait for execution. Raises a - :class:`TimeoutError` on timeout. ``None`` means wait - indefinitely. - args: Positional arguments to pass to the worker method. - kwargs: Keyword arguments to pass to the worker method. - - Returns: - A list containing the results from each worker. - - Note: - It is recommended to use this API to only pass control messages, - and set up data-plane communication to pass data. - """ - assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc" - - # Submit collective_rpc task to worker - rpc_id = str(uuid.uuid4()) - self._in_q.put( - { - "type": OmniStageTaskType.COLLECTIVE_RPC, - "rpc_id": rpc_id, - "method": method, - "timeout": timeout, - "args": args, - "kwargs": kwargs, - } - ) - - start_time = time.time() - while True: - if timeout is not None and (time.time() - start_time) > timeout: - raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") - - # Check if result was already collected by the orchestrator's - # output handler (stored in a shared dict). - result = None - if self._rpc_result_checker is not None: - result = self._rpc_result_checker(rpc_id) - - if result is not None: - if result.get("type") == "collective_rpc_result": - if result.get("rpc_id") == rpc_id: - if "error" in result: - raise RuntimeError(f"collective_rpc failed: {result['error']}") - return result["result"] - - time.sleep(0.001) # Small sleep to avoid busy waiting - - def init_stage_worker( - self, - model: str, - *, - is_async: bool = False, - shm_threshold_bytes: int = 65536, - ctx: mp.context.BaseContext | None = None, - batch_timeout: int = 10, - connectors_config: dict | None = None, - worker_backend: str = "multi_process", - ignore_runtime_config: bool = False, - **kwargs: Any, - ) -> None: - """Initialize and start the stage worker process. - - Creates a worker process that runs the LLM engine for this stage. 
- The worker handles batching, generation, and IPC communication. - - Args: - model: Model name or path to load - is_async: Whether to use async engine (default: False) - shm_threshold_bytes: Threshold for using shared memory for IPC - ctx: Optional multiprocessing context (default: spawn) - batch_timeout: Timeout in seconds for batching requests - connectors_config: Configuration for stage connectors - worker_backend: Backend type ("multi_process" or "ray") - ignore_runtime_config: Whether to ignore runtime configuration (default: False) - **kwargs: Additional arguments (e.g. ray_placement_group) - - Raises: - AssertionError: If queues are not attached before calling this method - """ - assert self._in_q is not None and self._out_q is not None, "Queues must be attached before start_process" - - if worker_backend == "ray": - ray_placement_group = kwargs.get("ray_placement_group", None) - assert ray_placement_group is not None, "Ray placement group must be provided" - self._shm_threshold_bytes = sys.maxsize - else: - self._shm_threshold_bytes = shm_threshold_bytes - - ctx = ctx or mp.get_context("spawn") - # Prepare lightweight dict config for worker - engine_args = _to_dict(self.engine_args) - if ignore_runtime_config: - runtime_cfg = {} - else: - runtime_cfg = _to_dict(getattr(self.stage_config, "runtime", {})) - stage_payload: dict[str, Any] = { - "stage_id": self.stage_id, - "engine_args": engine_args, - "runtime": runtime_cfg, - "shm_threshold_bytes": self._shm_threshold_bytes, - "connectors_config": connectors_config or {}, - "stage_type": self.stage_type, - "engine_input_source": self.engine_input_source, - "cfg_kv_collect_func": getattr(self.stage_config, "cfg_kv_collect_func", None), - "final_output": self.final_output, - "final_output_type": self.final_output_type, - } - try: - old_env = os.environ.get("VLLM_LOGGING_PREFIX") - new_env = f"[Stage-{self.stage_id}] {'' if old_env is None else old_env}" - os.environ["VLLM_LOGGING_PREFIX"] = new_env - if 
worker_backend == "ray": - if is_async: - self._ray_actor, self._ray_task_ref = start_ray_actor( - _stage_worker_async_entry, - ray_placement_group, - self.stage_id, - model=model, - stage_payload=stage_payload, - in_q=self._in_q, - out_q=self._out_q, - batch_timeout=batch_timeout, - stage_init_timeout=self._stage_init_timeout, - ) - else: - self._ray_actor, self._ray_task_ref = start_ray_actor( - _stage_worker, - ray_placement_group, - self.stage_id, - model=model, - stage_payload=stage_payload, - in_q=self._in_q, - out_q=self._out_q, - batch_timeout=batch_timeout, - stage_init_timeout=self._stage_init_timeout, - ) - else: - if is_async: - self._proc = ctx.Process( - target=_stage_worker_async_entry, - args=( - model, - stage_payload, - self._in_q.endpoint if isinstance(self._in_q, ZmqQueue) else self._in_q, - self._out_q.endpoint if isinstance(self._out_q, ZmqQueue) else self._out_q, - batch_timeout, - self._stage_init_timeout, - ), - ) - else: - self._proc = ctx.Process( - target=_stage_worker, - args=( - model, - stage_payload, - self._in_q.endpoint if isinstance(self._in_q, ZmqQueue) else self._in_q, - self._out_q.endpoint if isinstance(self._out_q, ZmqQueue) else self._out_q, - batch_timeout, - self._stage_init_timeout, - ), - ) - self._proc.start() - finally: - if old_env is None: - os.environ.pop("VLLM_LOGGING_PREFIX", None) - else: - os.environ["VLLM_LOGGING_PREFIX"] = old_env - - def stop_stage_worker(self) -> None: - """Stop the stage worker process gracefully. - - Sends shutdown signal to the worker and waits for it to terminate. - If graceful shutdown fails, forcefully terminates the process. - Handles both multiprocessing Process and Ray Actor. - """ - if self._in_q is not None: - try: - self._in_q.put_nowait(SHUTDOWN_TASK) - except Exception as e: - # Queue may already be closed by the weak-ref cleanup - # (e.g. ZmqQueue socket closed) — this is expected. 
- logger.debug("Failed to send shutdown to in_q: %s", e) - close_fn = getattr(self._in_q, "close", None) - if callable(close_fn): - try: - close_fn() - except Exception: - pass - self._in_q = None - if self._out_q is not None: - close_fn = getattr(self._out_q, "close", None) - if callable(close_fn): - try: - close_fn() - except Exception: - pass - self._out_q = None - - if self._ray_actor is not None: - kill_ray_actor(self._ray_actor) - self._ray_actor = None - self._ray_task_ref = None - elif self._proc is not None: - try: - self._proc.join(timeout=5) - except Exception as e: - logger.debug("join() failed: %s", e) - if self._proc.is_alive(): - try: - self._proc.terminate() - except Exception as e: - logger.warning("terminate() failed: %s", e) - - def submit(self, payload: dict[str, Any]) -> None: - """Submit a task to the stage worker. - - Args: - payload: Dictionary containing task data (request_id, engine_inputs, - sampling_params, etc.) - """ - assert self._in_q is not None - - # [Omni] Inject global request_id into additional_information for cross-stage ID consistency - # This allows workers (like GPUARModelRunner) to use the global ID for side-channel - # operations like KV transfer, even if they use internal IDs for execution. 
- if "request_id" in payload and "engine_inputs" in payload: - req_id = payload["request_id"] - ein = payload["engine_inputs"] - - # Helper to inject into additional_information - def _inject_global_id(target_ein): - # OmniTokensPrompt is a TypedDict at runtime, so we treat it as a dict - if isinstance(target_ein, dict): - if "additional_information" not in target_ein: - target_ein["additional_information"] = {} - - # Ensure additional_information is a dict before assignment - # (in case it was somehow initialized as None or other type) - if target_ein["additional_information"] is None: - target_ein["additional_information"] = {} - - if isinstance(target_ein["additional_information"], dict): - # Wrap in list because OmniInputProcessor requires Tensor or list values - target_ein["additional_information"]["global_request_id"] = [str(req_id)] - - if isinstance(ein, list): - for item in ein: - _inject_global_id(item) - else: - _inject_global_id(ein) - - self._in_q.put(payload) - - def try_collect(self) -> dict[str, Any] | None: - """Try to collect a result from the stage worker without blocking. - - Returns: - Result dictionary if available, None otherwise. Result contains - request_id, engine_outputs (or engine_outputs_shm), and metrics. 
- """ - assert self._out_q is not None - try: - return self._out_q.get_nowait() - except queue.Empty: - pass - except Exception as e: - logger.error("Unexpected error when collecting OmniStage output queue:", exc_info=e) - self.stop_stage_worker() - raise - if self._proc is not None and not self._proc.is_alive(): - raise RuntimeError(f"OmniStage Worker process died unexpectedly with exit code {self._proc.exitcode}") - if self._ray_task_ref is not None and not is_ray_task_alive(self._ray_task_ref, timeout=0): - e = get_ray_task_error(self._ray_task_ref, timeout=0) - raise RuntimeError("OmniStage Ray actor died unexpectedly") from e - - def process_engine_inputs( - self, - stage_list: list[Any], - prompt: OmniTokensPrompt | TextPrompt = None, - *, - source_outputs_override: Any = None, - ) -> list[OmniTokensPrompt | TextPrompt]: - """Process engine inputs for this stage from upstream stage outputs. - - Derives inputs for this stage from outputs of upstream stages. - Uses engine_input_source configuration to determine which upstream - stage outputs to use. Supports custom processing functions. - - Args: - stage_list: List of all stages in the pipeline - prompt: Optional original prompt (for multimodal data preservation) - source_outputs_override: Use these outputs instead of reading from - the source stage's ``engine_outputs`` (for deferred CFG requests). 
- - Returns: - List of processed engine inputs ready for this stage - - Raises: - ValueError: If engine_input_source is empty or invalid - """ - if self.custom_process_input_func is None: - engine_inputs = [] - if len(self.engine_input_source) == 0: - raise ValueError("engine_input_source is empty") - source_stage_id = self.engine_input_source[0] - source_outputs = ( - source_outputs_override - if source_outputs_override is not None - else stage_list[source_stage_id].engine_outputs - ) - if not isinstance(prompt, list): - prompt = [prompt] - multi_modal_data = { - source_output.request_id: p.get("multi_modal_data", None) - for source_output, p in zip(source_outputs, prompt) - } - - for source_output in source_outputs: - engine_input = OmniTokensPrompt( - prompt_token_ids=source_output.outputs[0].token_ids, - multi_modal_data=( - multi_modal_data[source_output.request_id] - if self.requires_multimodal_data and multi_modal_data - else None - ), - ) - engine_inputs.append(engine_input) - return engine_inputs - - else: - engine_input_source = self.engine_input_source - if source_outputs_override is not None and engine_input_source: - # Temporarily swap engine_outputs so custom_process_input_func - # (which reads stage_list directly) sees the correct data. 
- _source_id = engine_input_source[0] - _orig_outputs = stage_list[_source_id].engine_outputs - stage_list[_source_id].engine_outputs = source_outputs_override - try: - return self.custom_process_input_func( - stage_list, engine_input_source, prompt, self.requires_multimodal_data - ) - finally: - stage_list[_source_id].engine_outputs = _orig_outputs - return self.custom_process_input_func( - stage_list, engine_input_source, prompt, self.requires_multimodal_data - ) - - -def _stage_worker( - model: str, - stage_payload: dict[str, Any], - in_q: mp.queues.Queue | ZmqQueue | str, - out_q: mp.queues.Queue | ZmqQueue | str, - batch_timeout: int = 10, - stage_init_timeout: int = 300, -) -> None: - """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" - # Use local aliases to avoid conflicts with global imports in worker process - logger.info(f"Starting stage worker with model: {model}") - import multiprocessing as _mp - import os as _os - import time as _time - - import zmq - - from vllm_omni.plugins import load_omni_general_plugins - - load_omni_general_plugins() - # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / - # GPUARModelRunner) are spawned with a fork-safe method. - # Mooncake / gRPC / RDMA and CUDA/NCCL can deadlock under fork-with-threads. - if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": - _os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - logger.info("[Stage] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") - # Best-effort: also force python mp start method in this stage process. - # This may raise if already set; that's fine. 
- try: - _mp.set_start_method("spawn", force=True) - except RuntimeError: - pass - - stage_id = stage_payload["stage_id"] - engine_args = stage_payload.get("engine_args", {}) - runtime_cfg = stage_payload.get("runtime", {}) - shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) - connectors_config = stage_payload.get("connectors_config", {}) - stage_type: Literal["llm", "diffusion"] = stage_payload.get("stage_type", "llm") - - cfg_kv_collect_func = load_func_from_config(stage_payload.get("cfg_kv_collect_func")) - - if stage_type != "diffusion": - _resolve_worker_cls(engine_args) - - # Handle non-standard model directory structures (e.g., tokenizer in root, model in subdir) - model = _resolve_model_tokenizer_paths(model, engine_args) - - # Resolve ZMQ queue endpoints if needed - zmq_ctx = None - if isinstance(in_q, str) or isinstance(out_q, str): - zmq_ctx = zmq.Context() - if isinstance(in_q, str): - in_q = create_zmq_queue(zmq_ctx, in_q, zmq.PULL) - if isinstance(out_q, str): - out_q = create_zmq_queue(zmq_ctx, out_q, zmq.PUSH) - # When using ZMQ (cross-node IPC), disable SHM so data is sent inline. 
- shm_threshold_bytes = sys.maxsize - logger.info( - "[Stage-%s] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)", - stage_id, - ) - - # Aggregates for running average - _agg_total_tokens = 0 - _agg_total_gen_time_ms = 0.0 - # Monotonic batch id per stage process for orchestrator dedup on time aggregation - _batch_seq = 0 - - # Device mapping - device_type = None - try: - from vllm_omni.platforms import current_omni_platform - - device_type = current_omni_platform.device_type - set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) - except Exception as e: - logger.warning("Device setup failed: %s", e) - - # Use sequential init locks only when NVML is unavailable - with _sequential_init_lock(engine_args, stage_init_timeout): - # Init engine based on stage_type - logger.debug( - "[Stage-%s] Initializing %s engine with args keys=%s", stage_id, stage_type, list(engine_args.keys()) - ) - if engine_args.get("async_chunk", False): - logger.debug("[Stage-%s] Async chunk enabled, injecting connectors config", stage_id) - stage_connector_spec = {} - for v in connectors_config.values(): - stage_connector_spec = dict(v.get("spec", {})) - break - engine_args["stage_connector_spec"] = stage_connector_spec - engine_args["stage_id"] = stage_id - if stage_type == "diffusion": - engine_args = filter_dataclass_kwargs(OmniDiffusionConfig, engine_args) - engine_args.pop("model_stage", None) - engine_args.pop("model", None) - stage_engine = OmniDiffusion( - model=model, - stage_id=stage_id, - engine_input_source=stage_payload.get("engine_input_source", []), - cfg_kv_collect_func=cfg_kv_collect_func, - **engine_args, - ) - else: - engine_args = filter_dataclass_kwargs(OmniEngineArgs, engine_args) - engine_args.pop("model", None) - # Default to LLM engine - stage_engine = OmniLLM(model=model, **engine_args) - - logger.debug("Engine initialized") - # Initialize OmniConnectors if configured - connectors: dict[tuple[str, str], 
OmniConnectorBase] | None = {} - if connectors_config: - connectors = build_stage_connectors( - stage_id=stage_id, - connectors_config=connectors_config, - ) - if connectors is None: - return - - # Signal readiness to orchestrator - try: - out_q.put({"type": "stage_ready", "stage_id": stage_id}) - except Exception: - pass - - max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) - logger.info(f"Max batch size: {max_batch_size}") - - def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: - """Handle profiler task locally in the worker process.""" - if task_type == OmniStageTaskType.PROFILER_START: - if stage_type == "diffusion": - try: - profile_dir = _os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") - _os.makedirs(profile_dir, exist_ok=True) - trace_filename = f"stage_{stage_id}_diffusion_{int(_time.time())}" - stage_engine.start_profile(trace_filename=trace_filename) - logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) - except Exception as e: - logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) - else: - try: - stage_engine.start_profile() - logger.info("[Stage-%s] vLLM profiler started", stage_id) - except Exception as e: - logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) - return {} - - elif task_type == OmniStageTaskType.PROFILER_STOP: - if stage_type == "diffusion": - try: - # CRITICAL: Capture return value - result_data = stage_engine.stop_profile() - logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) - return result_data - except Exception as e: - logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) - return {} - else: - try: - stage_engine.stop_profile() - logger.info("[Stage-%s] vLLM profiler stopped", stage_id) - except Exception as e: - logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) - return {} - return {} - - # Batch processing loop - while True: - task = in_q.get() - - 
_recv_dequeue_ts = _time.time() - task_type = task.get("type", OmniStageTaskType.GENERATE) - if task_type == OmniStageTaskType.SHUTDOWN: - logger.info("Received shutdown signal") - break - - # Handle profiler control commands - if is_profiler_task(task_type): - profiler_data = handle_profiler_task_local(task_type) - # If it was a STOP command, we must reply to the Orchestrator - if task_type == OmniStageTaskType.PROFILER_STOP: - out_q.put({"type": "profiler_result", "data": profiler_data}) - continue - - # Handle collective_rpc commands - if task_type == OmniStageTaskType.COLLECTIVE_RPC: - rpc_id = task.get("rpc_id") - rpc_method = task.get("method") - rpc_timeout = task.get("timeout") - rpc_args = task.get("args", ()) - rpc_kwargs = task.get("kwargs") or {} - try: - rpc_result = stage_engine.collective_rpc( - method=rpc_method, - timeout=rpc_timeout, - args=rpc_args, - kwargs=rpc_kwargs, - ) - out_q.put( - { - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "stage_id": stage_id, - "result": rpc_result, - } - ) - except Exception as e: - logger.exception("[Stage-%s] collective_rpc failed: %s", stage_id, e) - out_q.put( - { - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "stage_id": stage_id, - "error": str(e), - } - ) - continue - - batch_tasks: list[dict[str, Any]] = [task] - tasks_failed_to_add_to_batch: list[dict[str, Any]] = [] - start_time = _time.time() - if max_batch_size > 1: - while len(batch_tasks) < max_batch_size: - if not in_q.empty(): - extra = in_q.get_nowait() - if extra == SHUTDOWN_TASK: - in_q.put(SHUTDOWN_TASK) - break - # Handle profiler commands that arrive during batching - extra_type = extra.get("type") if isinstance(extra, dict) else None - if is_profiler_task(extra_type): - p_data = handle_profiler_task_local(extra_type) - if extra_type == OmniStageTaskType.PROFILER_STOP: - out_q.put({"type": "profiler_result", "data": p_data}) - continue - # Ensure that all tasks have the same sampling params - # If no, put them in a 
temporary container and add back to queue - # This should be always true, because user only calls omni.generate() once and it blocks - # User can only pass one sampling param object, but the list of prompts are separated. - if task.get("sampling_params") != extra.get("sampling_params"): - logger.warning( - """In offline mode, expect all prompts in one `omni.generate()` call to share same sampling params""" # noqa: E501 # line too long - f"""However, prompt {task.get("engine_inputs")} has sampling params {task.get("sampling_params")}, """ # noqa: E501 # line too long - f"""whereas the prompt {extra.get("engine_inputs")} has sampling params {extra.get("sampling_params")}.""" # noqa: E501 # line too long - """The two tasks cannot be combined in one batch request.""" - ) - tasks_failed_to_add_to_batch.append(extra) - else: - batch_tasks.append(extra) - end_time = _time.time() - duration = end_time - start_time - if duration > batch_timeout: - break - else: - continue - else: - end_time = _time.time() - duration = end_time - start_time - _time.sleep(0.05) - if duration > batch_timeout: - break - else: - continue - for task_to_readd in tasks_failed_to_add_to_batch: - in_q.put(task_to_readd) - # Ensure that the popped tasks are with identical sampling params. Take one of them. 
- batch_engine_sampling_params: OmniSamplingParams = batch_tasks[0]["sampling_params"] - - batch_request_ids: list[Any] = [] - batch_engine_inputs: list[OmniPromptType] = [] - _rx_bytes_by_rid: dict[Any, int] = {} - _rx_decode_ms_by_rid: dict[Any, float] = {} - _in_flight_ms_by_rid: dict[Any, float] = {} - for t in batch_tasks: - rid = t["request_id"] - try: - sent_ts = float(t.get("sent_ts", None)) if isinstance(t, dict) else None - if sent_ts is not None: - _in_flight_ms_by_rid[rid] = max(0.0, (_recv_dequeue_ts - sent_ts) * 1000.0) - else: - _in_flight_ms_by_rid[rid] = 0.0 - except Exception: - _in_flight_ms_by_rid[rid] = 0.0 - - # Resolve input data strictly via connectors if payload - # is larger than shm_threshold_bytes or using other connectors - ein, _rx_metrics = try_recv_via_connector( - task=t, - connectors=connectors, - stage_id=stage_id, - ) - # TODO: hack type annotation for now. - # A better way is to refine type annotation of connection and task/payloads, maybe using template types. - ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) - - if ein is None or _rx_metrics is None: - raise RuntimeError( - f"[Stage-{stage_id}] Missing connector payload for request {rid}. " - "Ensure connectors are configured for all incoming edges." 
- ) - - _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0)) - _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) - - batch_request_ids.append(rid) - - if isinstance(ein, (dict, str)): - # For diffusion stage-0, ein might be a string prompt directly - batch_engine_inputs.append(ein) - elif isinstance(ein, Sequence): - batch_engine_inputs.extend(ein) - else: - # Other unknown types, append as-is - batch_engine_inputs.append(ein) - logger.debug( - "Received batch size=%d, request_ids=%s", - len(batch_tasks), - batch_request_ids, - ) - try: - _batch_seq += 1 - gen_outputs: list[OmniRequestOutput | RequestOutput] = [] - _gen_t0 = _time.time() - if stage_type == "diffusion": - stage_engine = cast(OmniDiffusion, stage_engine) - batch_engine_sampling_params = cast(OmniDiffusionSamplingParams, batch_engine_sampling_params) - # Diffusion generate returns results directly, not an iterator - diffusion_results = stage_engine.generate( - batch_engine_inputs, batch_engine_sampling_params, batch_request_ids - ) - gen_outputs.extend(diffusion_results) - # Assign request_ids if not present - for idx, result in enumerate(gen_outputs): - if not hasattr(result, "request_id") or result.request_id is None: - if idx < len(batch_request_ids): - result.request_id = batch_request_ids[idx] - else: - stage_engine = cast(OmniLLM, stage_engine) - batch_engine_sampling_params = cast(SamplingParams, batch_engine_sampling_params) - results = stage_engine.generate( - batch_engine_inputs, # type: ignore # silent complaints about list of subclassed TypedDict - batch_engine_sampling_params, - use_tqdm=False, - ) - gen_outputs.extend(results) - _gen_t1 = _time.time() - _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 - - # Group outputs per request id with fallback - req_to_outputs: dict[Any, list[Any]] = {rid: [] for rid in batch_request_ids} - unmapped: list[Any] = [] - for ro in gen_outputs: - rid = ro.request_id - if rid in req_to_outputs: - 
req_to_outputs[rid].append(ro) - else: - unmapped.append(ro) - if unmapped: - idx = 0 - for ro in unmapped: - target_rid = batch_request_ids[idx % len(batch_request_ids)] - ro.request_id = target_rid - req_to_outputs[target_rid].append(ro) - idx += 1 - - _agg_total_gen_time_ms += _gen_ms - - # Emit per-request results - for i, rid in enumerate(batch_request_ids): - r_outputs = req_to_outputs.get(rid, []) - _metrics = make_request_stats( - r_outputs, - _gen_ms, - int(_batch_seq), - int(len(batch_request_ids)), - float(_rx_decode_ms_by_rid.get(rid, 0.0)), - int(_rx_bytes_by_rid.get(rid, 0)), - float(_in_flight_ms_by_rid.get(rid, 0.0)), - ) - _agg_total_tokens += _metrics.num_tokens_out - if i == len(batch_request_ids) - 1: - _metrics.stage_stats = make_stage_stats(_agg_total_tokens, _agg_total_gen_time_ms) - else: - _metrics.stage_stats = None - try: - use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) - except Exception: - use_shm, payload = False, r_outputs - - try: - if use_shm: - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "engine_outputs_shm": payload, - "metrics": _metrics, - } - ) - else: - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "engine_outputs": payload, - "metrics": _metrics, - } - ) - except Exception: - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "engine_outputs": r_outputs, - "metrics": _metrics, - } - ) - except Exception as e: - logger.exception("Failed on batch %s: %s", batch_request_ids, e) - _tb = traceback.format_exc() - for rid in batch_request_ids: - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "error": str(e), - "error_tb": _tb, - } - ) - - -def _stage_worker_async_entry( - model: str, - stage_payload: dict[str, Any], - in_q: mp.queues.Queue | ZmqQueue | str, - out_q: mp.queues.Queue | ZmqQueue | str, - batch_timeout: int = 10, - stage_init_timeout: int = 300, -) -> None: - asyncio.run(_stage_worker_async(model, stage_payload, in_q, out_q, batch_timeout, 
stage_init_timeout)) - - -async def _stage_worker_async( - model: str, - stage_payload: dict[str, Any], - in_q: mp.queues.Queue | ZmqQueue | str, - out_q: mp.queues.Queue | ZmqQueue | str, - batch_timeout: int = 10, - stage_init_timeout: int = 300, -) -> None: - """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" - # Use local aliases to avoid conflicts with global imports in worker process - import multiprocessing as _mp - import os as _os - import time as _time - - import zmq - - from vllm_omni.plugins import load_omni_general_plugins - - load_omni_general_plugins() - # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / - # GPUARModelRunner) are spawned with a fork-safe method. - if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": - _os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - logger.info("[Stage-async] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") - try: - _mp.set_start_method("spawn", force=True) - except RuntimeError: - pass - - stage_id = stage_payload["stage_id"] - engine_args = stage_payload.get("engine_args", {}) - runtime_cfg = stage_payload.get("runtime", {}) - shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) - connectors_config = stage_payload.get("connectors_config", {}) - stage_type = stage_payload.get("stage_type", "llm") - final_output = stage_payload.get("final_output", False) - final_output_type = stage_payload.get("final_output_type", None) - - cfg_kv_collect_func = load_func_from_config(stage_payload.get("cfg_kv_collect_func")) - # Handle non-standard model directory structures (e.g., tokenizer in root, model in subdir) - model = _resolve_model_tokenizer_paths(model, engine_args) - - if stage_type != "diffusion": - _resolve_worker_cls(engine_args) - - # Resolve ZMQ queue endpoints if needed - zmq_ctx = None - if isinstance(in_q, str) or isinstance(out_q, str): - zmq_ctx = zmq.Context() - if isinstance(in_q, str): - in_q = create_zmq_queue(zmq_ctx, in_q, 
zmq.PULL) - if isinstance(out_q, str): - out_q = create_zmq_queue(zmq_ctx, out_q, zmq.PUSH) - # When using ZMQ (cross-node IPC), disable SHM so data is sent inline. - shm_threshold_bytes = sys.maxsize - logger.info( - "[Stage-%s] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)", - stage_id, - ) - - # Aggregates for running average - _agg_total_tokens = 0 - _agg_total_gen_time_ms = 0.0 - # Monotonic batch id per stage process for orchestrator dedup on time - # aggregation - _batch_seq = 0 - - # Device mapping - device_type = None - try: - from vllm_omni.platforms import current_omni_platform - - device_type = current_omni_platform.device_type - set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) - except Exception as e: - logger.warning("Device setup failed: %s", e) - - # Initialize OmniConnectors if configured to match sync worker behavior - connectors: dict[Any, Any] = {} - if connectors_config: - built_connectors = build_stage_connectors( - stage_id=stage_id, - connectors_config=connectors_config, - ) - if built_connectors is None: - return - connectors = built_connectors - - # Use sequential init locks only when NVML is unavailable - with _sequential_init_lock(engine_args, stage_init_timeout): - # Init engine based on stage_type - logger.debug( - "[Stage-%s] Initializing %s engine with args keys=%s", - stage_id, - stage_type, - list(engine_args.keys()), - ) - if engine_args.get("async_chunk", False): - logger.debug("[Stage-%s] Async chunk enabled, injecting connectors config", stage_id) - stage_connector_spec = {} - for v in connectors_config.values(): - stage_connector_spec = dict(v.get("spec", {})) - break - engine_args["stage_connector_spec"] = stage_connector_spec - engine_args["stage_id"] = stage_id - if stage_type == "diffusion": - # For diffusion, we need to extract diffusion-specific config - engine_args = filter_dataclass_kwargs(OmniDiffusionConfig, engine_args) - od_config = 
_build_od_config(engine_args, model) - - # Inject omni config for worker to access stage info - if "omni_kv_config" not in od_config: - od_config["omni_kv_config"] = {} - od_config["omni_kv_config"]["stage_id"] = stage_id - od_config["omni_kv_config"]["engine_input_source"] = stage_payload.get("engine_input_source", []) - - logger.debug(f"[Stage-%s] Initializing diffusion engine with config: {od_config}", stage_id) - _diffusion_kwargs = {k: v for k, v in engine_args.items() if k not in {"od_config", "model"}} - if cfg_kv_collect_func is not None: - _diffusion_kwargs["cfg_kv_collect_func"] = cfg_kv_collect_func - stage_engine = AsyncOmniDiffusion( - model=model, - od_config=od_config, - **_diffusion_kwargs, - ) - vllm_config = None # Diffusion doesn't use vllm_config - else: - engine_args = filter_dataclass_kwargs(AsyncOmniEngineArgs, engine_args) - engine_args.pop("model", None) - omni_engine_args = AsyncOmniEngineArgs(model=model, **engine_args) - usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = omni_engine_args.create_engine_config(usage_context=usage_context) - stage_engine = AsyncOmniLLM.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - engine_args=omni_engine_args, - disable_log_stats=bool( - engine_args.get("disable_log_stats", False) or getattr(omni_engine_args, "disable_log_stats", False) - ), - ) - if hasattr(stage_engine, "log_stats") and stage_engine.log_stats: - - async def _force_log(): - try: - while True: - await asyncio.sleep(10.0) - await stage_engine.do_log_stats() - except asyncio.CancelledError: - pass - - log_stats_task = asyncio.create_task(_force_log()) - else: - log_stats_task = None - - # Don't keep the dummy data in memory (only for LLM engines) - if stage_type != "diffusion": - await stage_engine.reset_mm_cache() - logger.debug("[Stage-%s] Engine initialized", stage_id) - - async def handle_profiler_task_async(task_type: OmniStageTaskType) -> dict: - """Handle profiler task asynchronously for both 
LLM and diffusion stages.""" - if task_type == OmniStageTaskType.PROFILER_START: - if stage_type == "diffusion": - try: - profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") - os.makedirs(profile_dir, exist_ok=True) - trace_filename = f"stage_{stage_id}_diffusion_{int(time.time())}" - await stage_engine.start_profile(trace_filename=trace_filename) - logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) - except Exception as e: - logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) - else: - try: - await stage_engine.start_profile() - logger.info("[Stage-%s] vLLM profiler started", stage_id) - except Exception as e: - logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) - return {} - - elif task_type == OmniStageTaskType.PROFILER_STOP: - result_data: dict = {} - if stage_type == "diffusion": - try: - trace_files = await stage_engine.stop_profile() - logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) - if trace_files: - logger.info("Diffusion trace files: %s", trace_files) - result_data = trace_files - except Exception as e: - logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) - else: - try: - await stage_engine.stop_profile() - logger.info("[Stage-%s] vLLM profiler stopped", stage_id) - except Exception as e: - logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) - return result_data - return {} - - # Signal readiness to orchestrator and send vllm_config back to main process - try: - # Send vllm_config back to main process so it can be accessed via - # get_vllm_config(). 
This is needed because async_engine is only available - # in the worker process - - # input_preprocessor = await stage_engine.get_input_preprocessor() - stage_ready_payload = { - "type": "stage_ready", - "stage_id": stage_id, - "vllm_config": vllm_config, - "tokenizer": getattr(stage_engine, "tokenizer", None), - } - # Only add is_tracing_enabled for LLM engines - if stage_type != "diffusion": - stage_ready_payload["is_tracing_enabled"] = await stage_engine.is_tracing_enabled() - out_q.put(stage_ready_payload) - except Exception as e: - logger.warning("Failed to send stage ready signal: %s", e) - generation_out_q = asyncio.Queue() - - # Batch processing loop - _rx_bytes_by_rid: dict[Any, int] = {} - _rx_decode_ms_by_rid: dict[Any, float] = {} - _in_flight_ms_by_rid: dict[Any, float] = {} - - async def generation_single_request(task: dict[str, Any]): - _recv_dequeue_ts = _time.time() - rid = task["request_id"] - try: - sent_ts = float(task.get("sent_ts", None)) if isinstance(task, dict) else None - if sent_ts is not None: - _in_flight_ms_by_rid[rid] = max(0.0, (_recv_dequeue_ts - sent_ts) * 1000.0) - else: - _in_flight_ms_by_rid[rid] = 0.0 - except Exception: - _in_flight_ms_by_rid[rid] = 0.0 - try: - ein, _rx_metrics = try_recv_via_connector( - task=task, - connectors=connectors, - stage_id=stage_id, - ) - # TODO: hack type annotation for now. - # A better way is to refine type annotation of connection and task/payloads, maybe using template types. - ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) - - if ein is None or _rx_metrics is None: - raise RuntimeError( - f"[Stage-{stage_id}] Missing connector payload for request {rid}. " - "Ensure connectors are configured for all incoming edges." 
- ) - _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0)) - _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) - - logger.debug("Received batch size=1, request_ids=%s", rid) - _gen_t0 = _time.time() - if isinstance(ein, Sequence) and not isinstance(ein, str): - ein = ein[0] - - if stage_type == "diffusion": - diffusion_sampling_params = cast(OmniDiffusionSamplingParams, task["sampling_params"]) - # AsyncOmniDiffusion.generate returns a single result, not an async generator - gen_output = await cast(AsyncOmniDiffusion, stage_engine).generate(ein, diffusion_sampling_params, rid) - _gen_t1 = _time.time() - _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 - await generation_out_q.put((rid, gen_output, _gen_ms)) - else: - ein = cast(PromptType, ein) - llm_sampling_params: SamplingParams = task["sampling_params"] - gen_output = None - async for res in cast(AsyncLLM, stage_engine).generate(ein, llm_sampling_params, rid): - gen_output = res - _gen_t1 = _time.time() - _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 - _gen_t0 = _gen_t1 - await generation_out_q.put((rid, gen_output, _gen_ms)) - except Exception as e: - logger.exception("Failed on request %s: %s", rid, e) - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "error": str(e), - } - ) - - _batch_gen_t0 = _time.time() - while True: - try: - task = in_q.get_nowait() - task_type = task.get("type", OmniStageTaskType.GENERATE) - if task_type == OmniStageTaskType.SHUTDOWN: - logger.debug("Received shutdown signal") - stage_engine.shutdown() - break - elif task_type == OmniStageTaskType.ABORT: - rid = task["request_id"] - asyncio.create_task(stage_engine.abort(rid)) - elif is_profiler_task(task_type): - profiler_data = await handle_profiler_task_async(task_type) - # Send result back to orchestrator for STOP command - if task_type == OmniStageTaskType.PROFILER_STOP: - out_q.put({"type": "profiler_result", "data": profiler_data}) - elif task_type == OmniStageTaskType.COLLECTIVE_RPC: - 
rpc_id = task.get("rpc_id") - rpc_method = task.get("method") - rpc_timeout = task.get("timeout") - rpc_args = task.get("args", ()) - rpc_kwargs = task.get("kwargs") or {} - try: - if stage_type == "diffusion": - # DiffusionEngine.collective_rpc is synchronous - loop = asyncio.get_event_loop() - rpc_result = await loop.run_in_executor( - None, - lambda: stage_engine.engine.collective_rpc( - method=rpc_method, - timeout=rpc_timeout, - args=rpc_args, - kwargs=rpc_kwargs, - ), - ) - else: - rpc_result = await cast(AsyncLLM, stage_engine).collective_rpc( - method=rpc_method, - timeout=rpc_timeout, - args=rpc_args, - kwargs=rpc_kwargs, - ) - out_q.put( - { - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "stage_id": stage_id, - "result": rpc_result, - } - ) - except Exception as e: - logger.exception( - "[Stage-%s] collective_rpc failed: %s", - stage_id, - e, - ) - out_q.put( - { - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "stage_id": stage_id, - "error": str(e), - } - ) - else: - asyncio.create_task(generation_single_request(task)) - - except queue.Empty: - await asyncio.sleep(0.001) - batch_request_outputs: list[Any] = [] - batch_request_ids: list[Any] = [] - _gen_ms_list = [] - batch_metrics: list[Any] = [] - while True: - try: - rid, gen_output, _gen_ms = generation_out_q.get_nowait() - _metrics = make_request_stats( - [gen_output], - _gen_ms, - int(_batch_seq), - 1, # temporarily set to 1 - float(_rx_decode_ms_by_rid.get(rid, 0.0)), - int(_rx_bytes_by_rid.get(rid, 0)), - float(_in_flight_ms_by_rid.get(rid, 0.0)), - ) - batch_metrics.append(_metrics) - batch_request_outputs.append(gen_output) - _gen_ms_list.append(_gen_ms) - batch_request_ids.append(rid) - _agg_total_tokens += _metrics.num_tokens_out - except asyncio.QueueEmpty: - await asyncio.sleep(0.001) - break - - if not batch_request_outputs: - continue - _batch_seq += 1 - - _batch_gen_t1 = _time.time() - _agg_total_gen_time_ms += (_batch_gen_t1 - _batch_gen_t0) * 1000 - _batch_gen_t0 = 
_batch_gen_t1 - for idx, metrics in enumerate(batch_metrics): - metrics.batch_size = len(batch_metrics) - if idx == len(batch_metrics) - 1: - metrics.stage_stats = make_stage_stats(_agg_total_tokens, _agg_total_gen_time_ms) - - logger.debug("Sending outputs to main process") - for rid, output, _gen_ms, _metrics in zip( - batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics - ): - try: - r_outputs = [output_strip(output, final_output, final_output_type)] - use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) - if use_shm: - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "engine_outputs_shm": payload, - "metrics": _metrics, - } - ) - else: - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "engine_outputs": payload, - "metrics": _metrics, - } - ) - logger.debug(f"Enqueued req={rid}, use_shm={use_shm}, tokens_out={_metrics.num_tokens_out}") - except Exception as e: - logger.exception( - "Failed to enqueue result for request %s: %s", - rid, - e, - ) - out_q.put( - { - "request_id": rid, - "stage_id": stage_id, - "engine_outputs": r_outputs, - "metrics": _metrics, - } - ) - logger.debug("Enqueued result for request %s to downstream", rid) - if log_stats_task is not None: - log_stats_task.cancel() - logger.info("Stage worker exiting") - - -def count_prompt_tokens_from_outputs(engine_outputs: list[Any]) -> int: - """Count prompt tokens from engine outputs.""" - total = 0 - for _ro in engine_outputs: - try: - prompt_token_ids = getattr(_ro, "prompt_token_ids", None) - if prompt_token_ids is not None: - total += len(prompt_token_ids) - except Exception: - pass - return total - - -def make_request_stats( - req_output: list[Any], - stage_gen_time_ms: float, - batch_id: int, - batch_size: int, - rx_decode_time_ms: float, - rx_transfer_bytes: int, - rx_in_flight_time_ms: float, -): - from vllm_omni.metrics import StageRequestStats - - num_tokens_in = count_prompt_tokens_from_outputs(req_output) - num_tokens_out = 
count_tokens_from_outputs(req_output) - return StageRequestStats( - num_tokens_in=num_tokens_in, - num_tokens_out=num_tokens_out, - stage_gen_time_ms=stage_gen_time_ms, - batch_id=batch_id, - batch_size=batch_size, - rx_decode_time_ms=rx_decode_time_ms, - rx_transfer_bytes=rx_transfer_bytes, - rx_in_flight_time_ms=rx_in_flight_time_ms, - stage_stats=None, - ) - - -def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float): - from vllm_omni.metrics import StageStats - - return StageStats(total_token=_agg_total_tokens, total_gen_time_ms=_agg_total_gen_time_ms) - - -def output_strip(r_output: RequestOutput | OmniRequestOutput, final_output: bool, final_output_type: str | None): - """ - Strip unnecessary multimodal outputs from stages results, - in order to: - - reduce memory usage - - reduce transfer & serialization overhead - """ - - # check multimodal data is required by stage output config. - if final_output and final_output_type != "text": - return r_output - - # If the request has already finished, should not be altered. 
- if getattr(r_output, "finished", False): - return r_output - - mm_output = getattr(r_output, "multimodal_output", None) - if mm_output is not None: - r_output.multimodal_output = {} - - custom_out = getattr(r_output, "_custom_output", None) - if custom_out is not None: - r_output._custom_output = {} - - outputs = getattr(r_output, "outputs", None) - if outputs is not None: - for out in outputs: - if getattr(out, "multimodal_output", None): - out.multimodal_output = {} - if getattr(out, "_custom_output", None): - out._custom_output = {} - - return r_output diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index c52038465a..c3c250fda7 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -111,10 +111,9 @@ from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo, ReferenceImage from vllm_omni.entrypoints.openai.storage import STORAGE_MANAGER from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS +from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request from vllm_omni.entrypoints.openai.video_api_utils import decode_input_reference from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt -from vllm_omni.lora.request import LoRARequest -from vllm_omni.lora.utils import stable_lora_int_id logger = init_logger(__name__) router = APIRouter() @@ -163,6 +162,12 @@ def _remove_route_from_router( ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format" +async def _get_vllm_config(engine_client: EngineClient) -> Any: + if hasattr(engine_client, "get_vllm_config"): + return await engine_client.get_vllm_config() + return getattr(engine_client, "vllm_config", None) + + def _remove_route_from_app(app, path: str, methods: set[str] | None = None): """Remove a route from the app by path and optionally by methods. 
@@ -187,13 +192,31 @@ class _DiffusionServingModels: provide a lightweight fallback. """ + class _NullModelConfig: + def __getattr__(self, name): + return None + + class _Unsupported: + def __init__(self, name: str): + self.name = name + + def __call__(self, *args, **kwargs): + raise NotImplementedError(f"{self.name} is not supported in diffusion mode") + + def __getattr__(self, attr): + raise NotImplementedError(f"{self.name}.{attr} is not supported in diffusion mode") + def __init__(self, base_model_paths: list[BaseModelPath]) -> None: self._base_model_paths = base_model_paths + self.model_config = self._NullModelConfig() @property def base_model_paths(self) -> list[BaseModelPath]: return self._base_model_paths + def __getattr__(self, name): + return self._Unsupported(name) + async def show_available_models(self) -> ModelList: return ModelList( data=[ @@ -276,7 +299,7 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, logger.warning("Profiler endpoints are enabled. 
This should ONLY be used for local development!") app.include_router(profiler_router) - vllm_config = await engine_client.get_vllm_config() + vllm_config = await _get_vllm_config(engine_client) # Check if pure diffusion mode (vllm_config will be None) is_pure_diffusion = vllm_config is None @@ -334,7 +357,7 @@ async def build_async_omni( Args: args: Parsed command-line arguments containing model and configuration disable_frontend_multiprocessing: Optional flag to disable frontend - multiprocessing (deprecated in V1) + multiprocessing client_config: Optional client configuration dictionary Yields: @@ -372,7 +395,7 @@ async def build_async_omni_from_stage_config( Args: args: Parsed command-line arguments containing model and stage configs disable_frontend_multiprocessing: Flag to disable frontend multiprocessing - (deprecated in V1) + for compatibility with existing CLI options client_config: Optional client configuration dictionary Yields: @@ -383,16 +406,13 @@ async def build_async_omni_from_stage_config( otherwise from the model's default configuration. """ - # V1 AsyncLLM. 
if disable_frontend_multiprocessing: - logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.") + logger.warning("Ignoring --disable-frontend-multiprocessing for AsyncOmni runtime.") async_omni: EngineClient | None = None try: - # Convert args Namespace to kwargs dict for AsyncOmni to use kwargs = vars(args).copy() - # Remove model as it will be passed separately kwargs.pop("model", None) async_omni = AsyncOmni(model=args.model, **kwargs) @@ -423,14 +443,14 @@ async def omni_init_app_state( args: Parsed command-line arguments """ # Get vllm_config from engine_client (following 0.14.0 pattern) - vllm_config = await engine_client.get_vllm_config() + vllm_config = await _get_vllm_config(engine_client) # Detect if it's pure Diffusion mode (single stage and is Diffusion) is_pure_diffusion = False if hasattr(engine_client, "stage_configs") and engine_client.stage_configs: stage_configs = engine_client.stage_configs if len(stage_configs) == 1: - stage_type = stage_configs[0].get("stage_type", "llm") + stage_type = get_stage_type(stage_configs[0]) if stage_type == "diffusion": is_pure_diffusion = True logger.info("Detected pure diffusion mode (single diffusion stage)") @@ -483,7 +503,7 @@ async def omni_init_app_state( # LLM or multi-stage mode: use standard initialization logic if vllm_config is None: # Try to get vllm_config from engine_client - vllm_config = await engine_client.get_vllm_config() + vllm_config = await _get_vllm_config(engine_client) if vllm_config is None: logger.warning("vllm_config is None, some features may not work correctly") @@ -528,15 +548,13 @@ async def omni_init_app_state( # Try to initialize processors if vllm_config is available try: from vllm.plugins.io_processors import get_io_processor - - from vllm_omni.engine.input_processor import OmniInputProcessor + from vllm.v1.engine.input_processor import InputProcessor tokenizer = await engine_client.get_tokenizer() if tokenizer is not None: # Initialize input_processor - # 
OMNI: OmniInputProcessor creates tokenizer internally from vllm_config if not hasattr(engine_client, "input_processor") or engine_client.input_processor is None: - engine_client.input_processor = OmniInputProcessor( + engine_client.input_processor = InputProcessor( vllm_config=vllm_config, ) logger.info("Initialized input_processor for AsyncOmni") @@ -1167,7 +1185,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) HTTPException: For validation errors, missing engine, or generation failures """ # Get engine client (AsyncOmni) from app state - engine_client, model_name, stage_types = _get_engine_and_model(raw_request) + engine_client, model_name, stage_configs = _get_engine_and_model(raw_request) # Validate model field (warn if mismatch, don't error) if request.model is not None and request.model != model_name: @@ -1219,7 +1237,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) result = await _generate_with_async_omni( engine_client=engine_client, gen_params=gen_params, - stage_types=stage_types, + stage_configs=stage_configs, prompt=prompt, request_id=request_id, ) @@ -1294,7 +1312,7 @@ async def edit_images( OpenAI-compatible image edit endpoint. """ # 1. get engine and model - engine_client, model_name, stage_types = _get_engine_and_model(raw_request) + engine_client, model_name, stage_configs = _get_engine_and_model(raw_request) if model is not None and model != model_name: logger.warning( f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model." @@ -1329,8 +1347,14 @@ async def edit_images( # 3.0 Init with system default values app_state_args = getattr(raw_request.app.state, "args", None) default_sample_param = getattr(app_state_args, "default_sampling_params", None) - # Currently only have one diffusion stage - diffusion_stage_id = [i for i, t in enumerate(stage_types) if t == "diffusion"][0] + # Currently only have one diffusion stage. 
+ diffusion_stage_ids = [i for i, cfg in enumerate(stage_configs) if get_stage_type(cfg) == "diffusion"] + if not diffusion_stage_ids: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="No diffusion stage found in multi-stage pipeline.", + ) + diffusion_stage_id = diffusion_stage_ids[0] apply_stage_default_sampling_params( default_sample_param, gen_params, @@ -1377,7 +1401,7 @@ async def edit_images( result = await _generate_with_async_omni( engine_client=engine_client, gen_params=gen_params, - stage_types=stage_types, + stage_configs=stage_configs, prompt=prompt, request_id=request_id, ) @@ -1418,42 +1442,26 @@ async def edit_images( def _get_engine_and_model(raw_request: Request): # Get engine client (AsyncOmni) from app state engine_client: EngineClient | AsyncOmni | None = getattr(raw_request.app.state, "engine_client", None) - if engine_client is None or not hasattr(engine_client, "stage_list"): + if engine_client is None: raise HTTPException( status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, detail="Multi-stage engine not initialized. Start server with a multi-stage omni model.", ) - # Check if there's a diffusion stage + # Check if there's a diffusion stage. + # Prefer app state (compat layer populated at startup), then fall back to + # the engine client's stage configs for refactored AsyncOmni paths. stage_configs = getattr(raw_request.app.state, "stage_configs", None) + if not stage_configs: + stage_configs = getattr(engine_client, "stage_configs", None) if not stage_configs: raise HTTPException( status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, detail="Stage configs not found. 
Start server with a multi-stage omni model.", ) - # Check for diffusion stage and collect stage types - has_diffusion_stage = False - stage_types: list[str] = [] - for stage in stage_configs: - # Handle both dict and OmegaConf objects - stage_type = None - if isinstance(stage, dict): - stage_type = stage.get("stage_type", "llm") - elif hasattr(stage, "get"): - stage_type = stage.get("stage_type", "llm") - elif hasattr(stage, "stage_type"): - stage_type = stage.stage_type - else: - # Fallback: try to access as dict-like - try: - stage_type = stage["stage_type"] if "stage_type" in stage else "llm" - except (TypeError, KeyError): - stage_type = "llm" - - if stage_type == "diffusion": - has_diffusion_stage = True - stage_types.append(stage_type) + normalized_stage_configs = list(stage_configs) + has_diffusion_stage = any(get_stage_type(stage_cfg) == "diffusion" for stage_cfg in normalized_stage_configs) if not has_diffusion_stage: raise HTTPException( @@ -1463,12 +1471,13 @@ def _get_engine_and_model(raw_request: Request): # Get server's loaded model name serving_models = getattr(raw_request.app.state, "openai_serving_models", None) - if serving_models and hasattr(serving_models, "base_model_paths") and serving_models.base_model_paths: - model_name = serving_models.base_model_paths[0].name + base_model_paths = getattr(serving_models, "base_model_paths", None) if serving_models else None + if base_model_paths: + model_name = base_model_paths[0].name else: model_name = "unknown" - return engine_client, model_name, stage_types + return engine_client, model_name, normalized_stage_configs def _get_lora_from_json_str(lora_body): @@ -1486,81 +1495,63 @@ def _get_lora_from_json_str(lora_body): def _parse_lora_request(lora_body: dict[str, Any]): - if lora_body is not None: - if not isinstance(lora_body, dict): - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Invalid lora field: expected an object.", - ) - lora_name = lora_body.get("name") or 
lora_body.get("lora_name") or lora_body.get("adapter") - lora_path = ( - lora_body.get("local_path") - or lora_body.get("path") - or lora_body.get("lora_path") - or lora_body.get("lora_local_path") - ) - lora_scale = lora_body.get("scale") - if lora_scale is None: - lora_scale = lora_body.get("lora_scale") - lora_int_id = lora_body.get("int_id") - if lora_int_id is None: - lora_int_id = lora_body.get("lora_int_id") - if lora_int_id is None and lora_path: - lora_int_id = stable_lora_int_id(str(lora_path)) - - if not lora_name or not lora_path: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Invalid lora object: both name and path are required.", - ) - - return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), lora_scale - return None, None + try: + return parse_lora_request(lora_body) + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=str(e), + ) from e async def _generate_with_async_omni( engine_client: AsyncOmni | Any, gen_params: Any, - stage_types: list[str], + stage_configs: list[Any], **kwargs, ): engine_client = cast(AsyncOmni, engine_client) result = None - stage_list = getattr(engine_client, "stage_list", None) - if isinstance(stage_list, list): - default_params_list: list[OmniSamplingParams] | None = getattr( - engine_client, "default_sampling_params_list", None + normalized_stage_configs = list(stage_configs) + if not normalized_stage_configs: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="Stage configs not found. 
Start server with a multi-stage omni model.", ) - if not isinstance(default_params_list, list): - default_params_list = [ - OmniDiffusionSamplingParams() if st == "diffusion" else SamplingParams() for st in stage_types - ] - else: - default_params_list = list(default_params_list) - if len(default_params_list) != len(stage_types): - default_params_list = ( - default_params_list - + [OmniDiffusionSamplingParams() if st == "diffusion" else SamplingParams() for st in stage_types] - )[: len(stage_types)] - - sampling_params_list: list[OmniSamplingParams] = [] - for idx, stage_type in enumerate(stage_types): - if stage_type == "diffusion": - sampling_params_list.append(gen_params) - else: - base_params = default_params_list[idx] - sampling_params_list.append(base_params) - - async for output in engine_client.generate( - sampling_params_list=sampling_params_list, - **kwargs, - ): - result = output + default_params_list: list[OmniSamplingParams] | None = getattr( + engine_client, + "default_sampling_params_list", + None, + ) + if not isinstance(default_params_list, list): + default_params_list = [] else: - result = await engine_client.generate( - sampling_params_list=[gen_params], - **kwargs, - ) + default_params_list = list(default_params_list) + + sampling_params_list: list[OmniSamplingParams] = [] + for idx, stage_cfg in enumerate(normalized_stage_configs): + stage_type = get_stage_type(stage_cfg) + if stage_type == "diffusion": + sampling_params_list.append(gen_params) + continue + + if idx < len(default_params_list): + default_stage_params = default_params_list[idx] + else: + default_stage_params = SamplingParams() + + if hasattr(default_stage_params, "clone"): + try: + default_stage_params = default_stage_params.clone() + except Exception: + pass + sampling_params_list.append(default_stage_params) + + async for output in engine_client.generate( + sampling_params_list=sampling_params_list, + **kwargs, + ): + result = output if result is None: raise HTTPException( diff 
--git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index ada4a698b4..1fc3cdb362 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -84,8 +84,8 @@ from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio +from vllm_omni.entrypoints.openai.utils import parse_lora_request from vllm_omni.lora.request import LoRARequest -from vllm_omni.lora.utils import stable_lora_int_id from vllm_omni.outputs import OmniRequestOutput if TYPE_CHECKING: @@ -326,6 +326,8 @@ async def create_chat_completion( tprompt: OmniTextPrompt = {"prompt": extracted_prompt} if is_img2img: tprompt["modalities"] = ["img2img"] + else: + tprompt["modalities"] = ["image"] if negative_prompt is not None: tprompt["negative_prompt"] = negative_prompt # GLM-Image's _call_hf_processor expects target_h/target_w in mm_processor_kwargs @@ -626,10 +628,10 @@ def _to_sampling_params_list(self, sampling_params_list: list[dict]) -> list[Sam return final_sampling_params_list def _get_comprehension_stage_index(self) -> int: - for idx, stage in enumerate(self.engine_client.stage_list): + for idx, stage in enumerate(self.engine_client.stage_configs): if stage.is_comprehension: return idx - raise ValueError("No comprehension stage (is_comprehension=True) found in stage_list") + raise ValueError("No comprehension stage (is_comprehension=True) found in stage configs") # OpenAI API standard sampling parameters that can be safely overridden. # These are the most commonly used parameters with compatible types @@ -2104,27 +2106,11 @@ async def _create_diffusion_chat_completion( # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). 
if lora_body and isinstance(lora_body, dict): try: - lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") - lora_path = ( - lora_body.get("local_path") - or lora_body.get("path") - or lora_body.get("lora_path") - or lora_body.get("lora_local_path") - ) - # using "or" directly here may be buggy if `scale=0` - lora_scale = lora_body.get("scale") - if lora_scale is None: - lora_scale = lora_body.get("lora_scale") - lora_int_id = lora_body.get("int_id") - if lora_int_id is None: - lora_int_id = lora_body.get("lora_int_id") - if lora_int_id is None and lora_path: - lora_int_id = stable_lora_int_id(str(lora_path)) - if lora_name and lora_path: - lora_req = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: gen_params.lora_request = lora_req if lora_scale is not None: - gen_params.lora_scale = float(lora_scale) + gen_params.lora_scale = lora_scale except Exception as e: # pragma: no cover - safeguard logger.warning("Failed to parse LoRA request: %s", e) @@ -2152,8 +2138,7 @@ async def _create_diffusion_chat_completion( # Generate image # Handle both AsyncOmniDiffusion (returns OmniRequestOutput) and AsyncOmni (returns AsyncGenerator) - if hasattr(self._diffusion_engine, "stage_list"): - # AsyncOmni: iterate through async generator to get final output + if isinstance(self._diffusion_engine, AsyncOmni): diffusion_engine = cast(AsyncOmni, self._diffusion_engine) result = None async for output in diffusion_engine.generate( diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 3b5d67542b..6c2928e627 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -197,12 +197,9 @@ def _load_codec_frame_rate(self) -> float | None: return None def _find_tts_stage(self): - """Find and return the TTS stage from the stage list, or None if not 
found.""" - stage_list = getattr(self.engine_client, "stage_list", None) - if stage_list is None: - return None - for stage in stage_list: - if getattr(stage, "model_stage", None) in _TTS_MODEL_STAGES: + """Find and return the TTS stage config, or None if not found.""" + for stage in self.engine_client.stage_configs: + if stage.engine_args.model_stage in _TTS_MODEL_STAGES: return stage return None @@ -496,13 +493,7 @@ async def delete_voice(self, name: str) -> bool: def _is_tts_model(self) -> bool: """Check if the current model is a supported TTS model.""" - stage_list = getattr(self.engine_client, "stage_list", None) - if stage_list: - for stage in stage_list: - model_stage = getattr(stage, "model_stage", None) - if model_stage in _TTS_MODEL_STAGES: - return True - return False + return any(stage.engine_args.model_stage in _TTS_MODEL_STAGES for stage in self.engine_client.stage_configs) def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate TTS request parameters. 
Returns error message or None.""" diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py index 6a16e7c701..48ffd49567 100644 --- a/vllm_omni/entrypoints/openai/serving_video.py +++ b/vllm_omni/entrypoints/openai/serving_video.py @@ -19,10 +19,9 @@ VideoGenerationRequest, VideoGenerationResponse, ) +from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request from vllm_omni.entrypoints.openai.video_api_utils import encode_video_base64 from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt -from vllm_omni.lora.request import LoRARequest -from vllm_omni.lora.utils import stable_lora_int_id logger = init_logger(__name__) @@ -161,38 +160,20 @@ async def generate_videos( @staticmethod def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None: - if lora_body is None: - return - if not isinstance(lora_body, dict): - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Invalid lora field: expected an object.", - ) - lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") - lora_path = ( - lora_body.get("local_path") - or lora_body.get("path") - or lora_body.get("lora_path") - or lora_body.get("lora_local_path") - ) - lora_scale = lora_body.get("scale") - if lora_scale is None: - lora_scale = lora_body.get("lora_scale") - lora_int_id = lora_body.get("int_id") - if lora_int_id is None: - lora_int_id = lora_body.get("lora_int_id") - if lora_int_id is None and lora_path: - lora_int_id = stable_lora_int_id(str(lora_path)) - - if not lora_name or not lora_path: + try: + lora_request, lora_scale = parse_lora_request(lora_body) + except ValueError as e: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, - detail="Invalid lora object: both name and path are required.", - ) + detail=str(e), + ) from e + + if lora_request is None: + return - gen_params.lora_request = 
LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + gen_params.lora_request = lora_request if lora_scale is not None: - gen_params.lora_scale = float(lora_scale) + gen_params.lora_scale = lora_scale async def _run_generation( self, @@ -200,44 +181,26 @@ async def _run_generation( gen_params: OmniDiffusionSamplingParams, request_id: str, ) -> Any: - has_stage_list = hasattr(self._engine_client, "stage_list") - logger.info( - "Video generation routing: stage_configs=%s, has_stage_list=%s, engine_type=%s", - "present" if self._stage_configs else "missing", - has_stage_list, - type(self._engine_client).__name__, - ) stage_configs = self._stage_configs or getattr(self._engine_client, "stage_configs", None) if not stage_configs: - if not hasattr(self._engine_client, "stage_list"): + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="Stage configs not found. Start server with an omni diffusion model.", + ) + + # Video generation endpoint only supports diffusion stages. + for stage in stage_configs: + stage_type = get_stage_type(stage) + if stage_type != "diffusion": raise HTTPException( status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, - detail="Stage configs not found. Start server with an omni diffusion model.", + detail=f"Video generation only supports diffusion stages, found '{stage_type}' stage.", ) - # Video generation endpoint only supports diffusion stages. 
- if stage_configs: - for stage in stage_configs: - # Extract stage_type: dicts and OmegaConf objects use .get(), others use getattr - if hasattr(stage, "get"): - stage_type = stage.get("stage_type", "llm") - else: - stage_type = getattr(stage, "stage_type", "llm") - - if stage_type != "diffusion": - raise HTTPException( - status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, - detail=f"Video generation only supports diffusion stages, found '{stage_type}' stage.", - ) - # Common generation logic for both paths engine_client = cast(AsyncOmni, self._engine_client) - stage_list = getattr(engine_client, "stage_list", None) - if isinstance(stage_list, list): - sampling_params_list: list[OmniSamplingParams] = [gen_params for _ in stage_list] - else: - sampling_params_list = [gen_params] + sampling_params_list: list[OmniSamplingParams] = [gen_params for _ in stage_configs] result = None async for output in engine_client.generate( diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py new file mode 100644 index 0000000000..84b28ef5b1 --- /dev/null +++ b/vllm_omni/entrypoints/openai/utils.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id + + +def get_stage_type(stage_cfg: Any) -> str: + """Best-effort stage type resolver across dict/omegaconf/object configs.""" + if isinstance(stage_cfg, dict): + return stage_cfg.get("stage_type", "llm") + if hasattr(stage_cfg, "get"): + try: + return stage_cfg.get("stage_type", "llm") + except Exception: + pass + return getattr(stage_cfg, "stage_type", "llm") + + +def parse_lora_request(lora_body: Any) -> tuple[LoRARequest | None, float | None]: + """Parse a request-level LoRA object into a LoRARequest and optional scale. + + Raises: + ValueError: If the object shape is invalid or required fields are missing. 
+ """ + if lora_body is None: + return None, None + + if not isinstance(lora_body, dict): + raise ValueError("Invalid lora field: expected an object.") + + lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") + lora_path = ( + lora_body.get("local_path") + or lora_body.get("path") + or lora_body.get("lora_path") + or lora_body.get("lora_local_path") + ) + lora_scale = lora_body.get("scale") + if lora_scale is None: + lora_scale = lora_body.get("lora_scale") + lora_int_id = lora_body.get("int_id") + if lora_int_id is None: + lora_int_id = lora_body.get("lora_int_id") + if lora_int_id is None and lora_path: + lora_int_id = stable_lora_int_id(str(lora_path)) + + if not lora_name or not lora_path: + raise ValueError("Invalid lora object: both name and path are required.") + + scale = float(lora_scale) if lora_scale is not None else None + return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), scale diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 99f2042e24..317d54322f 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -1,11 +1,7 @@ from __future__ import annotations -import enum -import importlib -import json import logging import os -from collections.abc import Callable from multiprocessing import shared_memory as _shm from typing import Any @@ -14,38 +10,6 @@ logger = logging.getLogger(__name__) -def load_func_from_config(func_path: str | None) -> Callable[..., Any] | None: - """Dynamically import a callable from a fully-qualified dotted path. - - Args: - func_path: Dotted path such as ``"pkg.module.func_name"``, or *None*. - - Returns: - The imported callable, or *None* when *func_path* is falsy. 
- """ - if not func_path: - return None - module_path, func_name = func_path.rsplit(".", 1) - module = importlib.import_module(module_path) - return getattr(module, func_name) - - -class OmniStageTaskType(enum.Enum): - GENERATE = "generate" - ABORT = "abort" - SHUTDOWN = "shutdown" - PROFILER_START = "profiler_start" - PROFILER_STOP = "profiler_stop" - COLLECTIVE_RPC = "collective_rpc" - - -SHUTDOWN_TASK = {"type": OmniStageTaskType.SHUTDOWN} - - -def is_profiler_task(task_type: OmniStageTaskType) -> bool: - return task_type in (OmniStageTaskType.PROFILER_START, OmniStageTaskType.PROFILER_STOP) - - def set_stage_devices( stage_id: int, devices: str | int | None, @@ -207,58 +171,6 @@ def shm_read_bytes(meta: dict[str, Any]) -> bytes: return data -def _ensure_parent_dir(path: str) -> None: - """Ensure the parent directory for a file path exists (best-effort).""" - try: - parent = os.path.dirname(path) - if parent: - os.makedirs(parent, exist_ok=True) - except Exception: - pass - - -def append_jsonl(path: str, record: dict[str, Any]) -> None: - """Append a JSON record as one line to a JSONL file (best-effort). - - This is safe to call from multiple processes when each process writes - to a distinct file. For concurrent writes to the same file, OS append - semantics typically suffice, but no additional locking is provided. - """ - try: - _ensure_parent_dir(path) - line = json.dumps(record, ensure_ascii=False) - fd = os.open(path, os.O_APPEND | os.O_CREAT | os.O_WRONLY, 0o644) - with os.fdopen(fd, "a", encoding="utf-8") as f: - f.write(line + "\n") - except Exception: - logger.exception("Failed to append JSONL to %s", path) - - -def maybe_dump_to_shm(obj: Any, threshold: int) -> tuple[bool, Any]: - """Dump object to SHM if serialized size exceeds threshold. - - Returns (True, meta) when dumped; otherwise (False, original_obj). 
- """ - payload = serialize_obj(obj) - if len(payload) > threshold: - logger.debug(f"Dumping object to SHM with size: {len(payload)}") - return True, shm_write_bytes(payload, name=None) - return False, obj - - -def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -> Any: - """Load object from container that may carry SHM or inline object. - - Deprecated: prefer `maybe_load_from_ipc_with_metrics` to also obtain - decode-time and size metrics. - """ - if shm_key in container: - from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer - - return OmniSerializer.deserialize(shm_read_bytes(container[shm_key])) - return container[obj_key] - - def maybe_load_from_ipc_with_metrics( container: dict[str, Any], obj_key: str, shm_key: str ) -> tuple[Any, dict[str, float]]: @@ -295,21 +207,6 @@ def maybe_load_from_ipc_with_metrics( } -def encode_for_ipc(obj: Any, threshold: int, obj_key: str, shm_key: str) -> dict[str, Any]: - """Return a dict payload for IPC: inline (obj_key) or SHM (shm_key). - - When serialized size exceeds threshold, returns {shm_key: {name,size}}; - otherwise returns {obj_key: obj}. 
- """ - payload: dict[str, Any] = {} - use_shm, data = maybe_dump_to_shm(obj, threshold) - if use_shm: - payload[shm_key] = data - else: - payload[obj_key] = data - return payload - - # Convert OmegaConf/objects to plain dicts def _to_dict(x: Any) -> dict[str, Any]: try: @@ -321,58 +218,3 @@ def _to_dict(x: Any) -> dict[str, Any]: return dict(x) except Exception: return {} - - -def _resolve_model_to_local_path(model: str) -> str: - """Resolve an HF Hub model ID to its local cache snapshot path.""" - if os.path.isdir(model): - return model - - try: - from huggingface_hub import snapshot_download - - # no network access is attempted, check local model path only - return snapshot_download(model, local_files_only=True) - except Exception: - logger.warning(f"Could not resolve {model} to a local snapshot path; using as-is", exc_info=True) - return model - - -def _resolve_model_tokenizer_paths( - model: str, - engine_args: dict, -) -> str: - """Resolve model and tokenizer paths for non-standard directory structures. - - Some models (e.g., GLM-Image) have tokenizer in root and model in subdirectory. - This function handles model_subdir and tokenizer_subdir engine_args. - - When the base model path is an HF Hub ID rather than an absolute local path, - the ID is first resolved to the local snapshot directory so that subdirectory - joins produce valid filesystem paths. 
- - Args: - model: Base model path or HF Hub model ID - engine_args: Engine arguments (modified in-place to remove subdir args - and set tokenizer if needed) - - Returns: - Resolved model path (may be subdirectory of original) - """ - model_subdir = engine_args.pop("model_subdir", None) - tokenizer_subdir = engine_args.pop("tokenizer_subdir", None) - resolved_base = _resolve_model_to_local_path(model) - - if model_subdir: - model = os.path.join(resolved_base, model_subdir) - logger.info(f"Using model subdirectory: {model}") - - if tokenizer_subdir is not None: - tokenizer_path = os.path.join(resolved_base, tokenizer_subdir) if tokenizer_subdir else resolved_base - engine_args["tokenizer"] = tokenizer_path - logger.info(f"Using tokenizer from: {tokenizer_path}") - elif model_subdir and "tokenizer" not in engine_args: - engine_args["tokenizer"] = resolved_base - logger.info(f"Using tokenizer from base model path: {resolved_base}") - - return model diff --git a/vllm_omni/entrypoints/zmq_utils.py b/vllm_omni/entrypoints/zmq_utils.py deleted file mode 100644 index 2ef5685cda..0000000000 --- a/vllm_omni/entrypoints/zmq_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -"""ZMQ-based queue utilities for Omni IPC.""" - -from __future__ import annotations - -import queue -from typing import Any - -import zmq -from vllm.utils.network_utils import make_zmq_socket - - -class ZmqQueue: - """Queue-like wrapper on a ZMQ socket.""" - - def __init__( - self, - ctx: zmq.Context, - socket_type: int, - *, - bind: str | None = None, - connect: str | None = None, - recv_timeout_ms: int | None = None, - send_timeout_ms: int | None = None, - ) -> None: - # Determine path and bind mode - path = bind if bind is not None else connect - if path is None: - raise ValueError("Either bind or connect must be specified") - bind_mode = bind is not None - - self._socket = make_zmq_socket(ctx, path, 
socket_type, bind=bind_mode, linger=5000) - - # Reusable poller for efficient polling operations - self._poller = zmq.Poller() - self._poller.register(self._socket, zmq.POLLIN) - - # Store default timeout settings - self._default_recv_timeout = recv_timeout_ms - self._default_send_timeout = send_timeout_ms - - # Apply timeout settings if specified - if recv_timeout_ms is not None: - self._socket.rcvtimeo = recv_timeout_ms - if send_timeout_ms is not None: - self._socket.sndtimeo = send_timeout_ms - - self.endpoint = path - - def put(self, obj: Any) -> None: - """Send an object to the queue. Blocks until sent or timeout.""" - try: - self._socket.send_pyobj(obj) - except zmq.Again as e: - raise queue.Full() from e - - def put_nowait(self, obj: Any) -> None: - """Send an object to the queue without blocking.""" - try: - self._socket.send_pyobj(obj, flags=zmq.NOBLOCK) - except zmq.Again as e: - raise queue.Full() from e - - def get(self, timeout: float | None = None) -> Any: - """Receive an object from the queue with optional timeout in seconds.""" - if timeout is None: - return self._socket.recv_pyobj() - - # Use the reusable poller for timeout handling - events = dict(self._poller.poll(int(timeout * 1000))) - if events.get(self._socket) == zmq.POLLIN: - return self._socket.recv_pyobj() - raise queue.Empty() - - def get_nowait(self) -> Any: - """Receive an object from the queue without blocking.""" - try: - return self._socket.recv_pyobj(flags=zmq.NOBLOCK) - except zmq.Again as e: - raise queue.Empty() from e - - def empty(self) -> bool: - """Check if the queue is empty without blocking.""" - events = dict(self._poller.poll(0)) - return events.get(self._socket) != zmq.POLLIN - - def close(self) -> None: - self._socket.close(0) - - -def create_zmq_queue(ctx: zmq.Context, endpoint: str, socket_type: int) -> ZmqQueue: - """Create a ZmqQueue from an endpoint string and socket type.""" - return ZmqQueue(ctx, socket_type, connect=endpoint) diff --git 
a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml new file mode 100644 index 0000000000..75e9dc1b19 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml @@ -0,0 +1,92 @@ +model_type: qwen3_tts +async_chunk: true + +stages: + - stage_id: 0 + model_stage: qwen3_tts + stage_type: llm + is_comprehension: true + input_sources: [] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + runtime: + devices: "0" + max_batch_size: 10 + engine_args: + model_arch: Qwen3TTSTalkerForConditionalGeneration + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + enforce_eager: false + trust_remote_code: true + async_scheduling: true + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.08 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + model_stage: code2wav + stage_type: llm + input_sources: [0] + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + final_output: true + final_output_type: audio + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_arch: Qwen3TTSCode2Wav + hf_overrides: + architectures: [Qwen3TTSCode2Wav] + enforce_eager: true + trust_remote_code: true + async_scheduling: true + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.08 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 32768 + input_connectors: + from_stage_0: 
connector_of_shared_memory + tts_args: + max_instructions_length: 500 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: true + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + codec_chunk_frames: 25 + codec_left_context_frames: 25 + +edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index 72597249cd..ba44c378bb 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -938,7 +938,7 @@ def _scan(obj: object, depth: int = 0) -> None: wav_candidates.append(obj) return if isinstance(obj, dict): - # Inlined ndarray/tensor payloads from OmniInputProcessor. + # Inlined ndarray/tensor payloads from the input processor. if obj.get("__ndarray__") and "data" in obj and "dtype" in obj and "shape" in obj: try: data = obj["data"] diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml index e8a603af44..119047fb15 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml @@ -1,10 +1,10 @@ -# Stage config for running Hunyuan-Image3.0 with architecture of OmniLLM. +# Stage config for running Hunyuan-Image3.0 for multi-stage omni runtime. # Stage 0: AR Model (vLLM implementation) # The following config has been verified on 8x L40S-48G GPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true # Run this stage in a separate process devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio.yaml index b824fcb41f..0807b68ca4 100644 --- a/vllm_omni/model_executor/stage_configs/mimo_audio.yaml +++ b/vllm_omni/model_executor/stage_configs/mimo_audio.yaml @@ -1,4 +1,4 @@ -# stage config for running mimo-audio with architecture of OmniLLM. +# stage config for running mimo-audio for multi-stage omni runtime. # The following config has been verified on 1x H20-96G GPU. async_chunk: false diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml index 7177aa8092..ef9e1dedb0 100644 --- a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml +++ b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml @@ -1,4 +1,4 @@ -# stage config for running mimo-audio with architecture of OmniLLM. +# stage config for running mimo-audio for multi-stage omni runtime. # The following config has been verified on 1x H20-96G GPU. async_chunk: true diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml index 3c05cffb72..c66d0fcd4e 100644 --- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml @@ -1,9 +1,9 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config has been verified on 2x H100-80G GPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) @@ -33,7 +33,7 @@ stage_args: repetition_penalty: 1.1 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "1" @@ -62,7 +62,7 @@ stage_args: stop_token_ids: [8294] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml index 5e379aa6b7..927d71b495 100644 --- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml @@ -1,9 +1,9 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config has been verified on 1x H100-80G GPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) @@ -34,7 +34,7 @@ stage_args: output_connectors: to_stage_1: mooncake_connector - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "1" @@ -66,7 +66,7 @@ stage_args: output_connectors: to_stage_2: mooncake_connector - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml index eb5bf71012..759590bfc7 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml @@ -7,7 +7,7 @@ async_chunk: false stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0" max_batch_size: 64 @@ -38,7 +38,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 64 @@ -69,7 +69,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 32 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml index 8d8d00d6bc..d38302f8e0 100644 --- 
a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml @@ -7,7 +7,7 @@ async_chunk: true stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0" max_batch_size: 64 @@ -42,7 +42,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 64 @@ -76,7 +76,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 64 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml index a52b20fcc4..b91937b0cb 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml @@ -6,7 +6,7 @@ # The following config has been verified on 2x H100-80G GPUs. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0" max_batch_size: 1 @@ -39,7 +39,7 @@ stage_args: to_stage_1: mooncake_connector - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 1 @@ -75,7 +75,7 @@ stage_args: to_stage_2: mooncake_connector - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "1" max_batch_size: 64 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index 570ed710c5..971c0de125 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -2,6 +2,7 @@ async_chunk: true stage_args: - stage_id: 0 stage_type: llm + is_comprehension: true runtime: devices: "0" max_batch_size: 10 @@ -53,7 +54,7 @@ stage_args: async_scheduling: true enable_prefix_caching: false engine_output_type: audio - gpu_memory_utilization: 0.2 + gpu_memory_utilization: 0.3 distributed_executor_backend: "mp" # Must be divisible by num_code_groups and cover (left_context + chunk). 
max_num_batched_tokens: 8192 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml index 5236668dc6..7770b9e7c2 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml @@ -6,6 +6,7 @@ async_chunk: true stage_args: - stage_id: 0 stage_type: llm + is_comprehension: true runtime: devices: "0" max_batch_size: 4 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml index da22953bba..b4ae91c9ab 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml @@ -2,6 +2,7 @@ async_chunk: false stage_args: - stage_id: 0 stage_type: llm + is_comprehension: true runtime: devices: "0" max_batch_size: 1 diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 6f035a81ec..be65ec98f7 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -135,14 +135,12 @@ def multimodal_output(self) -> dict[str, Any]: if self.request_output is None: return self._multimodal_output - request_outputs = self.request_output if isinstance(self.request_output, list) else [self.request_output] - for req_out in request_outputs: - # Check completion outputs first (where multimodal_output is attached) - for output in getattr(req_out, "outputs", []): - if mm := getattr(output, "multimodal_output", None): - return mm - if mm := getattr(req_out, "multimodal_output", None): + # Check completion outputs first (where multimodal_output is attached) + for output in getattr(self.request_output, "outputs", []): + if mm := getattr(output, "multimodal_output", None): return mm + if mm := getattr(self.request_output, "multimodal_output", None): + return mm return self._multimodal_output @property diff --git 
a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml index d021ef218e..02a6b0dd24 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml @@ -1,7 +1,7 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage @@ -28,7 +28,7 @@ stage_args: detokenize: True repetition_penalty: 1.1 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "1" @@ -55,7 +55,7 @@ stage_args: repetition_penalty: 1.05 stop_token_ids: [8294] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "2" # Example: use a different NPU than the previous stage; use "0" if single NPU diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml index 4ede584b59..27f0ffcc29 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml @@ -7,7 +7,7 @@ async_chunk: true stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0,1" max_batch_size: 10 @@ -39,7 +39,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "2" max_batch_size: 
10 @@ -70,7 +70,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "2" max_batch_size: 10 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml index 60659a9768..53185e4532 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml @@ -2,6 +2,7 @@ async_chunk: true stage_args: - stage_id: 0 stage_type: llm + is_comprehension: true runtime: devices: "0" max_batch_size: 1 diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml index 7887cd2bb0..8c4ffaa6fe 100644 --- a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml @@ -1,4 +1,4 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config has been verified on 2x H100-80G GPU. stage_args: diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml index 0e33d4892d..221b656027 100644 --- a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml @@ -1,9 +1,9 @@ -# stage config for running qwen2.5-omni with architecture of OmniLLM. +# stage config for running qwen2.5-omni for multi-stage omni runtime. # The following config is verified with 2 * Intel Arc Pro B60 XPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true # Run this stage in a separate process devices: "0" # Visible devices for this stage @@ -30,7 +30,7 @@ stage_args: detokenize: True repetition_penalty: 1.1 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "1" @@ -58,7 +58,7 @@ stage_args: stop_token_ids: [8294] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: process: true devices: "1" diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml index cdf0da9053..59fed7c786 100644 --- a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml @@ -6,7 +6,7 @@ # The following config is verified with 8 * Intel Arc Pro B60 XPU. 
stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "0,1,2,3" max_batch_size: 1 @@ -38,7 +38,7 @@ stage_args: repetition_penalty: 1.05 - stage_id: 1 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "4" max_batch_size: 1 @@ -70,7 +70,7 @@ stage_args: stop_token_ids: [2150] - stage_id: 2 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm # Use llm stage type for AR stages runtime: devices: "4" max_batch_size: 1 diff --git a/vllm_omni/transformers_utils/configs/mammoth_moda2.py b/vllm_omni/transformers_utils/configs/mammoth_moda2.py index e8a4bd6dd4..4acd098e21 100644 --- a/vllm_omni/transformers_utils/configs/mammoth_moda2.py +++ b/vllm_omni/transformers_utils/configs/mammoth_moda2.py @@ -127,7 +127,7 @@ def __init__( else: self.gen_vocab_start_index = gen_vocab_start_index - # NOTE: vLLM V1 uses `hf_text_config.vocab_size` for sampling parameter validation + # NOTE: vLLM uses `hf_text_config.vocab_size` for sampling parameter validation # (e.g., allowed_token_ids). Although MammothModa2's gen vocab is implemented via # independent gen_embed/gen_head, the overall vocab size should still cover the # gen vocab token ID range from the perspective of "output logits dimension".