From 203ec97986f1cec5d37857434da6dfeb9e3f1573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B8=B8=EC=9E=AC=EC=9D=80?= Date: Mon, 6 Apr 2026 09:18:02 +0900 Subject: [PATCH 1/4] feat: add HyperCLOVAX-SEED-Omni-8B vision pipeline, thinker, and stage config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - diffusion/models/hyperclovax_vision/: HyperCLOVAX vision diffusion pipeline (transformer, layers, vision_token_embedder, pipeline) - model_executor/models/hcx_omni/: HCX Omni thinker model - model_executor/stage_configs/hcx_omni.yaml: 3-stage pipeline config (Stage-0 LLM thinker, Stage-1 vision decoder, Stage-2 audio decoder) - model_executor/stage_input_processors/hyperclovax_seed_omni.py: thinker→vision/audio token routing - engine/, entrypoints/: arg_utils, input_processor, omni_llm, zmq_utils, stage_utils, cli/main integration - examples/online_serving/hcx_omni/: client demo and run script - tests/: e2e and unit tests for HCX Omni Co-Authored-By: Hyunjoon Jeong --- examples/online_serving/hcx_omni/README.md | 107 ++++ .../online_serving/hcx_omni/client_demo.py | 153 +++++ .../online_serving/hcx_omni/run_server.sh | 52 ++ tests/e2e/offline_inference/test_hcx_omni.py | 169 ++++++ tests/e2e/online_serving/test_hcx_omni.py | 129 ++++ tests/e2e/stage_configs/hcx_omni_ci.yaml | 93 +++ tests/unit/__init__.py | 0 tests/unit/model_executor/__init__.py | 0 .../test_hcx_omni_processing.py | 187 ++++++ .../models/hyperclovax_vision/__init__.py | 21 + .../hyperclovax_vision_transformer.py | 146 +++++ .../models/hyperclovax_vision/layers.py | 229 +++++++ .../pipeline_hyperclovax_vision.py | 429 +++++++++++++ .../hyperclovax_vision/transformer_usp.py | 341 +++++++++++ .../vision_token_embedder.py | 117 ++++ vllm_omni/engine/arg_utils.py | 93 +-- vllm_omni/engine/input_processor.py | 280 +-------- vllm_omni/entrypoints/cli/main.py | 1 + vllm_omni/entrypoints/omni.py | 572 ++++++++++++++---- vllm_omni/entrypoints/omni_diffusion.py | 131 ++-- vllm_omni/entrypoints/omni_llm.py | 60 +- vllm_omni/entrypoints/stage_utils.py | 96 ++- vllm_omni/entrypoints/zmq_utils.py | 95 +++ .../models/hcx_omni/__init__.py | 3 + .../models/hcx_omni/hcx_omni.py | 138 +++++ .../models/hcx_omni/hcx_omni_thinker.py | 130 ++++ vllm_omni/model_executor/models/registry.py | 112 ++++ .../stage_configs/hcx_omni.yaml | 102 ++++ .../hyperclovax_seed_omni.py | 156 +++++ 29 files changed, 3625 insertions(+), 517 deletions(-) create mode 100644 examples/online_serving/hcx_omni/README.md create mode 100644 examples/online_serving/hcx_omni/client_demo.py create mode 100755 examples/online_serving/hcx_omni/run_server.sh create mode 100644 tests/e2e/offline_inference/test_hcx_omni.py create mode 100644 tests/e2e/online_serving/test_hcx_omni.py create mode 100644 tests/e2e/stage_configs/hcx_omni_ci.yaml create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/model_executor/__init__.py create mode 100644 tests/unit/model_executor/test_hcx_omni_processing.py create mode 100644 vllm_omni/diffusion/models/hyperclovax_vision/__init__.py create mode 100644 vllm_omni/diffusion/models/hyperclovax_vision/hyperclovax_vision_transformer.py create mode 100644 vllm_omni/diffusion/models/hyperclovax_vision/layers.py create mode 100644 vllm_omni/diffusion/models/hyperclovax_vision/pipeline_hyperclovax_vision.py create mode 100644 vllm_omni/diffusion/models/hyperclovax_vision/transformer_usp.py create mode 100644 vllm_omni/diffusion/models/hyperclovax_vision/vision_token_embedder.py create mode 100644 vllm_omni/entrypoints/zmq_utils.py create mode 100644 vllm_omni/model_executor/models/hcx_omni/__init__.py create mode 100644 vllm_omni/model_executor/models/hcx_omni/hcx_omni.py create mode 100644 vllm_omni/model_executor/models/hcx_omni/hcx_omni_thinker.py create mode 100644 vllm_omni/model_executor/stage_configs/hcx_omni.yaml create mode 100644 vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py diff --git a/examples/online_serving/hcx_omni/README.md b/examples/online_serving/hcx_omni/README.md new file mode 100644 index 00000000000..8ebce61deb4 --- /dev/null +++ b/examples/online_serving/hcx_omni/README.md @@ -0,0 +1,107 @@ +# HyperCLOVAX-SEED-Omni-8B with vLLM-Omni + +[HyperCLOVAX-SEED-Omni-8B](https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B) +is an omni-modal model by NAVER Cloud that supports: + +| Input | Output | +|--------|-----------------| +| Text | Text | +| Audio | Text + Audio | +| Image | Text | +| Text | Text + Image | +| Audio | Text + Audio + Image | + +## Architecture + +The model uses a 3-stage pipeline: + +``` +Stage 0 (Thinker) ──→ Stage 1 (Vision Decoder, diffusion) + │ + └──────────→ Stage 2 (Audio Decoder, unit-BigVGAN) +``` + +- **Thinker**: Qwen2.5-VL vision encoder + Qwen2Audio encoder + HyperCLOVAX language model. + Outputs text tokens and discrete audio/vision codes in the vocabulary. +- **Vision Decoder**: Diffusion-based image generation from 729 discrete TA-Tok codes. +- **Audio Decoder**: Unit-BigVGAN vocoder from CosyVoice2 FSQ discrete audio codes. + +## Hardware Requirements + +| Setup | GPUs | +|-----------|---------------------------------------------| +| Default | 6 × GPU ≥24 GB (4 for thinker TP, 1+1 for decoders) | +| Minimal | 3 × GPU ≥24 GB (1 for thinker, 1+1 for decoders) | + +## Quick Start + +### 1. Start the Server + +```bash +# 6-GPU setup (production) +./run_server.sh --model naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B + +# Custom GPU allocation +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 ./run_server.sh +``` + +### 2. Run the Client Demo + +```bash +# All modes: text-only, text-to-vision, speech-to-speech +python client_demo.py --base-url http://localhost:8000/v1 + +# Speech-to-Speech with your own audio file +python client_demo.py --mode s2s --audio-file /path/to/speech.wav + +# Text-to-Vision +python client_demo.py --mode t2v --prompt "고양이 그림을 그려줘" +``` + +### 3. Use the OpenAI API Directly + +**Speech-to-Speech:** +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", + "modalities": ["text", "audio"], + "messages": [{ + "role": "user", + "content": [ + {"type": "input_audio", "input_audio": {"data": "", "format": "wav"}}, + {"type": "text", "text": "이 오디오에 무슨 내용이 있나요?"} + ] + }] + }' +``` + +**Text-to-Vision:** +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", + "modalities": ["text", "image"], + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "귀여운 강아지 한 마리가 공원에서 뛰노는 그림을 그려줘."} + ] + }] + }' +``` + +## Stage Config + +The default stage config is at +`vllm_omni/model_executor/stage_configs/hcx_omni.yaml`. + +Key parameters: + +| Stage | Type | `model_arch` / `model_class_name` | GPU | +|-------|-----------|------------------------------------|-------| +| 0 | LLM | `HCXVisionV2ForCausalLM` | 0-3 | +| 1 | Diffusion | `HyperCLOVAXVisionPipeline` | 4 | +| 2 | Diffusion | `HyperCLOVAXAudioPipeline` | 5 | diff --git a/examples/online_serving/hcx_omni/client_demo.py b/examples/online_serving/hcx_omni/client_demo.py new file mode 100644 index 00000000000..b07529f0a30 --- /dev/null +++ b/examples/online_serving/hcx_omni/client_demo.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""HyperCLOVAX-SEED-Omni-8B client demo. + +Demonstrates Speech-to-Speech and Text-to-Vision via the OpenAI-compatible +HTTP API provided by vLLM-Omni. + +Usage: + # Start the server first (see run_server.sh), then: + python client_demo.py --base-url http://localhost:8000/v1 + + # With a local audio file: + python client_demo.py --audio-file path/to/speech.wav + + # Text-to-Vision only: + python client_demo.py --mode t2v --prompt "고양이 그림을 그려줘" +""" +import argparse +import base64 +import io +import sys +from pathlib import Path + +try: + from openai import OpenAI +except ImportError: + print("Please install openai: pip install openai") + sys.exit(1) + + +def encode_audio_file(path: str) -> str: + """Base64-encode a WAV/MP3 file.""" + with open(path, "rb") as f: + return base64.b64encode(f.read()).decode() + + +def encode_audio_array(array, sample_rate: int = 16000) -> str: + """Base64-encode a numpy audio array as WAV.""" + import numpy as np + import scipy.io.wavfile as wav + + if not isinstance(array, np.ndarray): + array = np.array(array) + buf = io.BytesIO() + wav.write(buf, sample_rate, (array * 32767).astype(np.int16)) + return base64.b64encode(buf.getvalue()).decode() + + +def speech_to_speech(client: OpenAI, audio_b64: str, prompt: str = "이 오디오에 무슨 내용이 있나요?"): + """Send audio → receive text + audio.""" + print(f"\n[Speech-to-Speech] prompt: {prompt!r}") + response = client.chat.completions.create( + model="naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", + modalities=["text", "audio"], + messages=[ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {"data": audio_b64, "format": "wav"}, + }, + {"type": "text", "text": prompt}, + ], + } + ], + ) + choice = response.choices[0] + print(f"Text response: {choice.message.content}") + if hasattr(choice.message, "audio") and choice.message.audio: + audio_data = base64.b64decode(choice.message.audio.data) + out_path = Path("/tmp/hcx_omni_response.wav") + out_path.write_bytes(audio_data) + print(f"Audio saved to: {out_path}") + return response + + +def text_to_vision(client: OpenAI, prompt: str = "귀여운 강아지 한 마리가 공원에서 뛰노는 그림을 그려줘."): + """Send text → receive text + image.""" + print(f"\n[Text-to-Vision] prompt: {prompt!r}") + response = client.chat.completions.create( + model="naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", + modalities=["text", "image"], + messages=[ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ], + ) + choice = response.choices[0] + print(f"Text response: {choice.message.content}") + if hasattr(choice.message, "image") and choice.message.image: + img_data = base64.b64decode(choice.message.image.data) + out_path = Path("/tmp/hcx_omni_generated.png") + out_path.write_bytes(img_data) + print(f"Image saved to: {out_path}") + return response + + +def text_only(client: OpenAI, prompt: str = "대한민국의 수도는 어디인가요?"): + """Pure text conversation (thinker only).""" + print(f"\n[Text-only] prompt: {prompt!r}") + response = client.chat.completions.create( + model="naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", + modalities=["text"], + messages=[ + {"role": "user", "content": prompt} + ], + ) + print(f"Response: {response.choices[0].message.content}") + return response + + +def main(): + parser = argparse.ArgumentParser(description="HyperCLOVAX-SEED-Omni-8B demo") + parser.add_argument("--base-url", default="http://localhost:8000/v1") + parser.add_argument( + "--mode", + choices=["s2s", "t2v", "text", "all"], + default="all", + help="Demo mode: s2s=Speech-to-Speech, t2v=Text-to-Vision, text=Text-only", + ) + parser.add_argument("--audio-file", default=None, help="Path to input audio file") + parser.add_argument("--prompt", default=None, help="Text prompt override") + args = parser.parse_args() + + client = OpenAI(api_key="EMPTY", base_url=args.base_url) + + if args.mode in ("text", "all"): + text_only(client, prompt=args.prompt or "대한민국의 수도는 어디인가요?") + + if args.mode in ("t2v", "all"): + text_to_vision(client, prompt=args.prompt or "귀여운 강아지 한 마리가 공원에서 뛰노는 그림을 그려줘.") + + if args.mode in ("s2s", "all"): + if args.audio_file: + audio_b64 = encode_audio_file(args.audio_file) + else: + # Generate synthetic 1-second sine wave + try: + import numpy as np + t = np.linspace(0, 1, 16000, endpoint=False) + audio_array = np.sin(2 * np.pi * 440 * t).astype(np.float32) + audio_b64 = encode_audio_array(audio_array) + except ImportError: + print("numpy not available, skipping S2S demo") + return + speech_to_speech(client, audio_b64, prompt=args.prompt or "이 오디오에 무슨 내용이 있나요?") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/hcx_omni/run_server.sh b/examples/online_serving/hcx_omni/run_server.sh new file mode 100755 index 00000000000..c3cbafba4b2 --- /dev/null +++ b/examples/online_serving/hcx_omni/run_server.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Launch HyperCLOVAX-SEED-Omni-8B with vLLM-Omni. +# +# Requirements: +# - 6× GPUs (≥24 GB VRAM each): +# GPU 0-3: Thinker (tensor_parallel_size=4) +# GPU 4 : Vision decoder +# GPU 5 : Audio decoder +# - HF model: naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B +# +# Usage: +# ./run_server.sh [--model MODEL] [--port PORT] [--stage-configs-path PATH] + +set -e + +MODEL="${MODEL:-naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B}" +PORT="${PORT:-8000}" +HOST="${HOST:-0.0.0.0}" +STAGE_CONFIG="${STAGE_CONFIG:-}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DEFAULT_STAGE_CONFIG="$SCRIPT_DIR/../../../vllm_omni/model_executor/stage_configs/hcx_omni.yaml" + +while [[ $# -gt 0 ]]; do + case $1 in + --model) MODEL="$2"; shift 2 ;; + --port) PORT="$2"; shift 2 ;; + --host) HOST="$2"; shift 2 ;; + --stage-configs-path) STAGE_CONFIG="$2"; shift 2 ;; + --help) + echo "Usage: $0 [--model MODEL] [--port PORT] [--host HOST] [--stage-configs-path PATH]" + exit 0 ;; + *) echo "Unknown: $1"; exit 1 ;; + esac +done + +[[ -z "$STAGE_CONFIG" ]] && STAGE_CONFIG="$DEFAULT_STAGE_CONFIG" + +echo "=================================================" +echo " HyperCLOVAX-SEED-Omni-8B vLLM-Omni Server" +echo "=================================================" +echo " Model : $MODEL" +echo " Stage config: $STAGE_CONFIG" +echo " Endpoint : http://$HOST:$PORT/v1" +echo "=================================================" + +python -m vllm_omni.entrypoints.openai.api_server \ + --model "$MODEL" \ + --stage-configs-path "$STAGE_CONFIG" \ + --port "$PORT" \ + --host "$HOST" \ + --trust-remote-code diff --git a/tests/e2e/offline_inference/test_hcx_omni.py b/tests/e2e/offline_inference/test_hcx_omni.py new file mode 100644 index 00000000000..9c0dba9df5e --- /dev/null +++ b/tests/e2e/offline_inference/test_hcx_omni.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""E2E tests for HyperCLOVAX-SEED-Omni-8B. + +Tests cover: + - Text-only inference (comprehension) + - Speech-to-Speech (audio input → audio output) + - Text-to-Vision (text input → image output) + - Audio-to-Vision (audio input → image + audio output) +""" +from pathlib import Path + +import pytest + +from tests.conftest import ( + generate_synthetic_audio, + generate_synthetic_image, + modify_stage_config, +) +from tests.utils import hardware_test +from vllm_omni.platforms import current_omni_platform + +MODEL = "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B" + +_CI_YAML = str( + Path(__file__).parent.parent / "stage_configs" / "hcx_omni_ci.yaml" +) + + +def _ci_config(enforce_eager: bool = True) -> str: + updates: dict = { + "stage_args": { + 0: {"engine_args.enforce_eager": str(enforce_eager).lower()}, + } + } + return modify_stage_config(_CI_YAML, updates=updates) + + +stage_config = _ci_config(enforce_eager=True) +test_params = [(MODEL, stage_config)] + + +# ------------------------------------------------------------------ # +# Helper # +# ------------------------------------------------------------------ # + +def _text_prompt(text: str) -> dict: + return { + "role": "user", + "content": [{"type": "text", "text": text}], + } + + +def _audio_text_prompt(audio_array, text: str) -> dict: + return { + "role": "user", + "content": [ + {"type": "input_audio", "input_audio": {"data": audio_array, "format": "wav"}}, + {"type": "text", "text": text}, + ], + } + + +def _image_text_prompt(image_array, text: str) -> dict: + return { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_array}}, + {"type": "text", "text": text}, + ], + } + + +# ------------------------------------------------------------------ # +# Tests # +# ------------------------------------------------------------------ # + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4"}, + num_cards={"cuda": 3}, +) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_text_to_text(omni_runner, omni_runner_handler) -> None: + """Text-only (comprehension) request — verifies thinker stage alone.""" + request_config = { + "prompts": "What is the capital of South Korea?", + "output_modalities": ["text"], + } + results = omni_runner.run(request_config) + assert results and len(results) > 0 + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4"}, + num_cards={"cuda": 3}, +) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_audio_to_audio(omni_runner, omni_runner_handler) -> None: + """Speech-to-Speech: audio input processed by thinker → audio decoder.""" + audio = generate_synthetic_audio(1, 1, 16000)["np_array"] + if len(audio.shape) == 2: + audio = audio.squeeze() + + request_config = { + "prompts": _audio_text_prompt(audio, "Repeat what you heard."), + "output_modalities": ["text", "audio"], + } + results = omni_runner.run(request_config) + assert results and len(results) > 0 + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4"}, + num_cards={"cuda": 3}, +) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_text_to_image(omni_runner, omni_runner_handler) -> None: + """Text-to-Vision: text prompt → image generated by vision decoder.""" + request_config = { + "prompts": "Draw a picture of a cat sitting on a sofa.", + "output_modalities": ["text", "image"], + } + results = omni_runner.run(request_config) + assert results and len(results) > 0 + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4"}, + num_cards={"cuda": 3}, +) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_image_to_text(omni_runner, omni_runner_handler) -> None: + """Image understanding: image input → text description.""" + image = generate_synthetic_image(224, 224)["np_array"] + request_config = { + "prompts": _image_text_prompt(image, "Describe this image."), + "output_modalities": ["text"], + } + results = omni_runner.run(request_config) + assert results and len(results) > 0 + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test( + res={"cuda": "L4"}, + num_cards={"cuda": 3}, +) +@pytest.mark.parametrize("omni_runner", test_params, indirect=True) +def test_multimodal_to_multimodal(omni_runner, omni_runner_handler) -> None: + """Full omni: audio + image input → text + audio + image output.""" + audio = generate_synthetic_audio(1, 1, 16000)["np_array"] + if len(audio.shape) == 2: + audio = audio.squeeze() + + request_config = { + "prompts": "Listen to the audio and draw what you hear.", + "output_modalities": ["text", "audio", "image"], + } + results = omni_runner.run(request_config) + assert results and len(results) > 0 diff --git a/tests/e2e/online_serving/test_hcx_omni.py b/tests/e2e/online_serving/test_hcx_omni.py new file mode 100644 index 00000000000..4b33d9c2416 --- /dev/null +++ b/tests/e2e/online_serving/test_hcx_omni.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""E2E online serving tests for HyperCLOVAX-SEED-Omni-8B. + +Tests the OpenAI-compatible HTTP API for Speech-to-Speech and +Text-to-Vision generation. +""" +import os +from pathlib import Path + +import pytest + +from tests.conftest import ( + OmniServerParams, + generate_synthetic_audio, + generate_synthetic_image, + modify_stage_config, +) +from tests.utils import hardware_test + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +MODEL = "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B" +_CI_YAML = str( + Path(__file__).parent.parent / "stage_configs" / "hcx_omni_ci.yaml" +) + +test_params = [ + OmniServerParams(model=MODEL, stage_config_path=_CI_YAML) +] + +SYSTEM_PROMPT = { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "당신은 CLOVA X입니다. 네이버가 만든 AI 어시스턴트로서 " + "오디오와 이미지를 인식하고 텍스트, 음성, 이미지를 생성할 수 있습니다." + ), + } + ], +} + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 3}) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_speech_to_speech(omni_server, omni_server_handler) -> None: + """Speech-to-Speech: audio input → text + audio response.""" + audio = generate_synthetic_audio(1, 1, 16000)["np_array"] + if len(audio.shape) == 2: + audio = audio.squeeze() + + messages = [ + SYSTEM_PROMPT, + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {"data": audio, "format": "wav"}, + }, + {"type": "text", "text": "이 오디오에서 무슨 내용이 들리나요?"}, + ], + }, + ] + request_config = { + "messages": messages, + "modalities": ["text", "audio"], + "stream": False, + } + response = omni_server.chat(request_config) + assert response is not None + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 3}) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_text_to_vision(omni_server, omni_server_handler) -> None: + """Text-to-Vision: text prompt → text + image response.""" + messages = [ + SYSTEM_PROMPT, + { + "role": "user", + "content": [ + {"type": "text", "text": "고양이 한 마리가 소파에 앉아 있는 그림을 그려줘."}, + ], + }, + ] + request_config = { + "messages": messages, + "modalities": ["text", "image"], + "stream": False, + } + response = omni_server.chat(request_config) + assert response is not None + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards={"cuda": 3}) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_image_understanding(omni_server, omni_server_handler) -> None: + """Image understanding: image input → text description.""" + image = generate_synthetic_image(224, 224)["np_array"] + messages = [ + SYSTEM_PROMPT, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image}"}, + }, + {"type": "text", "text": "이 이미지에 무엇이 있나요?"}, + ], + }, + ] + request_config = { + "messages": messages, + "modalities": ["text"], + "stream": False, + } + response = omni_server.chat(request_config) + assert response is not None diff --git a/tests/e2e/stage_configs/hcx_omni_ci.yaml b/tests/e2e/stage_configs/hcx_omni_ci.yaml new file mode 100644 index 00000000000..f455422689e --- /dev/null +++ b/tests/e2e/stage_configs/hcx_omni_ci.yaml @@ -0,0 +1,93 @@ +# Stage config for HyperCLOVAX-SEED-Omni-8B CI tests. +# Verified on 3x 24GB GPU (L4/RTX3090/RTX4090). +# Stage 0 (thinker): 4xTP → single GPU in CI uses 1xTP + +runtime: + connectors: + shared_memory_connector: + extra: + shm_threshold_bytes: 65536 + name: SharedMemoryConnector + defaults: + max_inflight: 1 + window_size: -1 + edges: + - { from: 0, to: 1, window_size: -1 } + - { from: 0, to: 2, window_size: -1 } + enabled: true + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0" + engine_args: + model_stage: thinker + model_arch: HCXVisionV2ForCausalLM + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + tensor_parallel_size: 1 + max_model_len: 4096 + max_num_batched_tokens: 4096 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + limit_mm_per_prompt: + audio: 1 + image: 1 + load_format: dummy + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.1 + top_p: 1.0 + top_k: -1 + max_tokens: 128 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "1" + max_batch_size: 1 + engine_args: + engine_output_type: image + gpu_memory_utilization: 0.75 + model_class_name: HyperCLOVAXVisionPipeline + model_stage: decoder/vision + model_subdir: decoder/vision + trust_remote_code: true + enforce_eager: true + engine_input_source: + - 0 + final_output: true + final_output_type: image + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni.thinker2vision_decoder + + - stage_id: 2 + stage_type: diffusion + runtime: + process: true + devices: "2" + max_batch_size: 1 + engine_args: + engine_output_type: audio + gpu_memory_utilization: 0.4 + model_class_name: HyperCLOVAXAudioPipeline + model_stage: decoder/audio + model_subdir: decoder/audio/NCZSCosybigvganDecoder.mar + trust_remote_code: true + enforce_eager: true + engine_input_source: + - 0 + final_output: true + final_output_type: audio + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni.thinker2audio_decoder diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/model_executor/__init__.py b/tests/unit/model_executor/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/model_executor/test_hcx_omni_processing.py b/tests/unit/model_executor/test_hcx_omni_processing.py new file mode 100644 index 00000000000..61ebeb3748c --- /dev/null +++ b/tests/unit/model_executor/test_hcx_omni_processing.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for HCXOmni multimodal token processing. + +Tests verify that: + 1. Audio tokens are correctly positioned and embedded. + 2. Image tokens (continuous path via Qwen2.5-VL) are correctly positioned. + 3. Discrete audio/image token boundaries match config.json values. + 4. Stage input processors correctly extract discrete tokens from mixed output. +""" +import pytest +import torch + +# Token ID boundaries (from HyperCLOVAX-SEED-Omni-8B config.json) +DISCRETE_AUDIO_UNIT_0_ID = 128606 +DISCRETE_IMAGE_UNIT_0_ID = 135168 +DISCRETE_AUDIO_VOCAB_SIZE = 6561 +DISCRETE_IMAGE_VOCAB_SIZE = 65536 +DISCRETE_IMAGE_TOKEN_LENGTH = 729 # 27 * 27 + + +class TestDiscreteTokenBoundaries: + """Verify token ID arithmetic matches config.json.""" + + def test_audio_range(self): + assert DISCRETE_AUDIO_UNIT_0_ID == 128606 + assert DISCRETE_AUDIO_UNIT_0_ID + DISCRETE_AUDIO_VOCAB_SIZE - 1 == 135166 + + def test_image_range(self): + assert DISCRETE_IMAGE_UNIT_0_ID == 135168 + # Image codebook is 2^16 = 65536 + assert DISCRETE_IMAGE_VOCAB_SIZE == 65536 + + def test_no_overlap(self): + audio_end = DISCRETE_AUDIO_UNIT_0_ID + DISCRETE_AUDIO_VOCAB_SIZE + assert audio_end < DISCRETE_IMAGE_UNIT_0_ID, ( + "Audio and image token ranges must not overlap" + ) + + def test_image_token_count_is_square(self): + """TA-Tok produces 27×27 = 729 tokens per image.""" + import math + side = math.isqrt(DISCRETE_IMAGE_TOKEN_LENGTH) + assert side * side == DISCRETE_IMAGE_TOKEN_LENGTH + + +class TestExtractDiscreteTokens: + """Test the _extract_discrete_tokens helper.""" + + def _extract(self, token_ids, start_id, vocab_size): + return [ + tid - start_id + for tid in token_ids + if start_id <= tid < start_id + vocab_size + ] + + def test_extract_audio_tokens(self): + token_ids = [ + 100, 200, # text + DISCRETE_AUDIO_UNIT_0_ID, + DISCRETE_AUDIO_UNIT_0_ID + 42, + DISCRETE_AUDIO_UNIT_0_ID + 100, + 300, # text + ] + result = self._extract(token_ids, DISCRETE_AUDIO_UNIT_0_ID, DISCRETE_AUDIO_VOCAB_SIZE) + assert result == [0, 42, 100] + + def test_extract_image_tokens(self): + token_ids = [ + 100, + DISCRETE_IMAGE_UNIT_0_ID, + DISCRETE_IMAGE_UNIT_0_ID + 255, + 200, + ] + result = self._extract(token_ids, DISCRETE_IMAGE_UNIT_0_ID, DISCRETE_IMAGE_VOCAB_SIZE) + assert result == [0, 255] + + def test_no_overlap_extraction(self): + """Audio extraction must not pick up image tokens and vice versa.""" + mixed = [ + DISCRETE_AUDIO_UNIT_0_ID + 5, + DISCRETE_IMAGE_UNIT_0_ID + 5, + ] + audio = self._extract(mixed, DISCRETE_AUDIO_UNIT_0_ID, DISCRETE_AUDIO_VOCAB_SIZE) + image = self._extract(mixed, DISCRETE_IMAGE_UNIT_0_ID, DISCRETE_IMAGE_VOCAB_SIZE) + assert audio == [5] + assert image == [5] + + def test_truncate_and_pad_image(self): + """Vision decoder needs exactly DISCRETE_IMAGE_TOKEN_LENGTH codes.""" + codes = list(range(DISCRETE_IMAGE_TOKEN_LENGTH + 50)) # too long + truncated = codes[:DISCRETE_IMAGE_TOKEN_LENGTH] + assert len(truncated) == DISCRETE_IMAGE_TOKEN_LENGTH + + codes_short = list(range(100)) # too short + padded = codes_short + [0] * (DISCRETE_IMAGE_TOKEN_LENGTH - len(codes_short)) + assert len(padded) == DISCRETE_IMAGE_TOKEN_LENGTH + + +class TestStageInputProcessor: + """Test thinker2vision_decoder and thinker2audio_decoder processors.""" + + def _make_fake_output(self, token_ids: list[int]): + """Create a minimal fake EngineCoreOutput-like object.""" + from types import SimpleNamespace + output = SimpleNamespace( + token_ids=token_ids, + ) + thinker_out = SimpleNamespace( + outputs=[output], + request_id="test-001", + prompt_token_ids=[1, 2, 3], + ) + return thinker_out + + def test_vision_decoder_extracts_image_tokens(self): + """thinker2vision_decoder should extract exactly 729 image tokens.""" + image_codes = list(range(DISCRETE_IMAGE_UNIT_0_ID, + DISCRETE_IMAGE_UNIT_0_ID + DISCRETE_IMAGE_TOKEN_LENGTH)) + audio_codes = list(range(DISCRETE_AUDIO_UNIT_0_ID, + DISCRETE_AUDIO_UNIT_0_ID + 20)) + token_ids = [100, 200] + audio_codes + image_codes + [300] + + thinker_out = self._make_fake_output(token_ids) + + from types import SimpleNamespace + stage_list = {0: SimpleNamespace(engine_outputs=[thinker_out])} + + from vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni import ( + thinker2vision_decoder, + ) + results = thinker2vision_decoder(stage_list, [0]) + assert len(results) == 1 + prompt_ids = results[0]["prompt_token_ids"] + assert len(prompt_ids) == DISCRETE_IMAGE_TOKEN_LENGTH + assert all(0 <= tid < DISCRETE_IMAGE_VOCAB_SIZE for tid in prompt_ids) + + def test_audio_decoder_extracts_audio_tokens(self): + """thinker2audio_decoder should extract discrete audio tokens.""" + audio_codes = list(range(DISCRETE_AUDIO_UNIT_0_ID, + DISCRETE_AUDIO_UNIT_0_ID + 50)) + token_ids = [100, 200] + audio_codes + [300] + + thinker_out = self._make_fake_output(token_ids) + + from types import SimpleNamespace + stage_list = {0: SimpleNamespace(engine_outputs=[thinker_out])} + + from vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni import ( + thinker2audio_decoder, + ) + results = thinker2audio_decoder(stage_list, [0]) + assert len(results) == 1 + additional = results[0]["additional_information"] + audio_tokens = additional["audio_tokens"][0] + assert len(audio_tokens) == 50 + assert all(0 <= tid < DISCRETE_AUDIO_VOCAB_SIZE for tid in audio_tokens) + + def test_vision_decoder_no_output_if_no_image_tokens(self): + """thinker2vision_decoder returns empty list when no image tokens present.""" + token_ids = [100, 200, 300] # text only + + thinker_out = self._make_fake_output(token_ids) + + from types import SimpleNamespace + stage_list = {0: SimpleNamespace(engine_outputs=[thinker_out])} + + from vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni import ( + thinker2vision_decoder, + ) + results = thinker2vision_decoder(stage_list, [0]) + assert results == [] + + def test_audio_decoder_no_output_if_no_audio_tokens(self): + """thinker2audio_decoder returns empty list when no audio tokens present.""" + token_ids = [100, 200, 300] # text only + + thinker_out = self._make_fake_output(token_ids) + + from types import SimpleNamespace + stage_list = {0: SimpleNamespace(engine_outputs=[thinker_out])} + + from vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni import ( + thinker2audio_decoder, + ) + results = thinker2audio_decoder(stage_list, [0]) + assert results == [] diff --git a/vllm_omni/diffusion/models/hyperclovax_vision/__init__.py b/vllm_omni/diffusion/models/hyperclovax_vision/__init__.py new file mode 100644 index 00000000000..fbf54a827d8 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_vision/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""HyperCLOVAX Vision Decoder diffusion model components.""" + +from vllm_omni.diffusion.models.hyperclovax_vision.hyperclovax_vision_transformer import ( + HyperCLOVAXVisionTransformer2DModel, +) +from vllm_omni.diffusion.models.hyperclovax_vision.pipeline_hyperclovax_vision import ( + HyperCLOVAXVisionPipeline, + get_hyperclovax_vision_post_process_func, +) +from vllm_omni.diffusion.models.hyperclovax_vision.vision_token_embedder import ( + VisionTokenEmbedder, +) + +__all__ = [ + "HyperCLOVAXVisionPipeline", + "HyperCLOVAXVisionTransformer2DModel", + "VisionTokenEmbedder", + "get_hyperclovax_vision_post_process_func", +] diff --git a/vllm_omni/diffusion/models/hyperclovax_vision/hyperclovax_vision_transformer.py b/vllm_omni/diffusion/models/hyperclovax_vision/hyperclovax_vision_transformer.py new file mode 100644 index 00000000000..eea0bc1c5c5 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_vision/hyperclovax_vision_transformer.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NAVER Cloud Corp. vision-decoder-api + +""" +HyperCLOVAX Vision Transformer for vision token to image generation. + +This module implements the VisionTransformer diffusion model that converts +vision token embeddings to latent representations for image generation. +""" + +import torch +import torch.nn as nn + +from vllm_omni.diffusion.data import OmniDiffusionConfig + +from .layers import ( + EmbedAND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +class HyperCLOVAXVisionTransformer2DModel(nn.Module): + """ + Vision Transformer for vision token to image generation. + + This transformer processes vision token embeddings concatenated with + noisy latents to predict noise for the diffusion process. + + Architecture: + - Input projection: (in_channels + context_in_dim) -> hidden_size + - Time embedding: 256 -> hidden_size + - Vector embedding: context_in_dim -> hidden_size + - Position embedding: EmbedAND with 3D axes + - Single stream blocks: 35 parallel attention+MLP blocks + - Output layer: hidden_size -> out_channels + + Args: + od_config: OmniDiffusionConfig containing model configuration + in_channels: Number of latent channels (default: 16) + vec_in_dim: Vision pooler output dimension (default: 1536) + context_in_dim: Vision hidden state dimension (default: 1536) + hidden_size: Transformer hidden dimension (default: 1920) + mlp_ratio: MLP expansion ratio (default: 4.0) + num_heads: Number of attention heads (default: 24) + depth_single_blocks: Number of single stream blocks (default: 35) + axes_dim: Position embedding axes dimensions (default: [8, 36, 36]) + theta: RoPE theta parameter (default: 10000) + use_patchify: Whether to use 2x2 patchification (default: False) + """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + in_channels: int = 16, + vec_in_dim: int = 1536, + context_in_dim: int = 1536, + hidden_size: int = 1920, + mlp_ratio: float = 4.0, + num_heads: int = 24, + depth_single_blocks: int = 35, + axes_dim: tuple[int, int, int] = (8, 36, 36), + theta: int = 10_000, + use_patchify: bool = False, + ): + super().__init__() + + self.od_config = od_config + self.in_channels = in_channels + self.context_in_dim = context_in_dim + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_heads = num_heads + self.use_patchify = use_patchify + self.depth_single_blocks = depth_single_blocks + + if hidden_size % num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") + + pe_dim = hidden_size // num_heads + axes_dim_list = list(axes_dim) + if sum(axes_dim_list) != pe_dim: + raise ValueError(f"Got {axes_dim_list} but expected positional dim {pe_dim}") + + # Position embedding + self.pe_embedder = EmbedAND(dim=pe_dim, theta=theta, axes_dim=axes_dim_list) + + # Input projections + self.img_in = nn.Linear(in_channels + context_in_dim, hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size) + self.vector_in = MLPEmbedder(vec_in_dim, hidden_size) + + # Single stream blocks + self.single_blocks = nn.ModuleList( + [SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single_blocks)] + ) + + # Output layer + self.final_layer = LastLayer(hidden_size, 1, self.out_channels) + + def forward( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + timesteps: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass of the transformer. + + Args: + img: Input tensor (B, L, in_channels + context_in_dim) + Concatenation of noisy latents and vision spatial features + img_ids: Position IDs tensor (B, L, 3) + timesteps: Sigma/timestep tensor (B,) in [0, 1] + y: Vision pooler output tensor (B, vec_in_dim) + + Returns: + Output tensor (B, L, out_channels) - predicted noise + """ + if img.ndim != 3: + raise ValueError("Input img tensor must have 3 dimensions.") + + # Project input + img = self.img_in(img) + + # Time and vector embedding + vec = self.time_in( + timestep_embedding(timesteps, 256).to(dtype=self.time_in.in_layer.weight.dtype, device=img.device) + ) + vec = vec + self.vector_in(y) + + # Position embedding + pe = self.pe_embedder(img_ids) + + # Single stream blocks + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + + # Final projection + img = self.final_layer(img, vec) + + return img diff --git a/vllm_omni/diffusion/models/hyperclovax_vision/layers.py b/vllm_omni/diffusion/models/hyperclovax_vision/layers.py new file mode 100644 index 00000000000..73c7f8eeca3 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_vision/layers.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NAVER Cloud Corp. vision-decoder-api + +""" +Common layers for HyperCLOVAX Vision Decoder. + +This module contains utility layers used in the VisionTransformer: +- RoPE (Rotary Position Embedding) +- EmbedAND (N-dimensional position embedding) +- MLPEmbedder (MLP for timestep and vector embeddings) +- RMSNorm (Root Mean Square Layer Normalization) +- QKNorm (Query-Key normalization) +- Modulation (Adaptive layer normalization modulation) +- SingleStreamBlock (Parallel attention and MLP block) +- LastLayer (Final projection layer) +""" + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +# Flash Attention support detection +try: + from torch.nn.attention import SDPBackend, sdpa_kernel + + FLASH_ATTN_AVAILABLE = torch.cuda.is_available() +except ImportError: + FLASH_ATTN_AVAILABLE = False + sdpa_kernel = None + SDPBackend = None + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + """Rotary Position Embedding computation.""" + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary position embedding to query and key.""" + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor) -> torch.Tensor: + """Attention with rotary position embedding and Flash Attention optimization.""" + q, k = apply_rope(q, k, pe) + + # Use Flash Attention when available for better memory efficiency and speed + if FLASH_ATTN_AVAILABLE and q.is_cuda: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = F.scaled_dot_product_attention(q, k, v) + else: + # Fallback to default SDPA (will use best available backend) + x = F.scaled_dot_product_attention(q, k, v) + + x = rearrange(x, "B H L D -> B L (H D)") + return x + + +@torch.no_grad() +def timestep_embedding( + t: torch.Tensor, + dim: int, + max_period: float = 10000, + time_factor: float = 1000.0, +) -> torch.Tensor: + """Create sinusoidal timestep embeddings.""" + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class EmbedAND(nn.Module): + """N-dimensional position embedding.""" + + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +class MLPEmbedder(nn.Module): + """MLP for timestep and vector embeddings.""" + + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization.""" + + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(nn.Module): + """Query-Key normalization.""" + + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +@dataclass +class ModulationOut: + shift: torch.Tensor + scale: torch.Tensor + gate: torch.Tensor + + +class Modulation(nn.Module): + """Adaptive layer normalization modulation.""" + + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(F.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class SingleStreamBlock(nn.Module): + """Single stream transformer block (parallel attention and MLP).""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor) -> torch.Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + attn = attention(q, k, v, pe=pe) + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + """Final projection layer with adaptive normalization.""" + + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/vllm_omni/diffusion/models/hyperclovax_vision/pipeline_hyperclovax_vision.py b/vllm_omni/diffusion/models/hyperclovax_vision/pipeline_hyperclovax_vision.py new file mode 100644 index 00000000000..05f6d547cd3 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_vision/pipeline_hyperclovax_vision.py @@ -0,0 +1,429 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NAVER Cloud Corp. vision-decoder-api + +""" +HyperCLOVAX Vision Pipeline for vLLM-Omni. + +This pipeline converts vision tokens to images using a VisionTransformer +diffusion model. It supports: +- Vision token embedding +- Flow matching diffusion +- Autoguidance (optional transformer2) +- xDiT USP sequence parallelism +""" + +import json +import logging +import os +from collections.abc import Iterable + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import AutoencoderKL +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from einops import rearrange, repeat +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.transformers_utils.config import get_hf_file_to_dict + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .hyperclovax_vision_transformer import HyperCLOVAXVisionTransformer2DModel +from .vision_token_embedder import VisionTokenEmbedder + +logger = logging.getLogger(__name__) + + +def get_hyperclovax_vision_post_process_func(od_config: OmniDiffusionConfig): + """ + Get post-processing function for HyperCLOVAX Vision pipeline. + + Returns a function that converts model output tensors to PIL images. + """ + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, + ) + + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + + # Load VAE config to get scale factor + vae_config_path = os.path.join(model_path, "vae/config.json") + if os.path.exists(vae_config_path): + with open(vae_config_path) as f: + config = json.load(f) + # Use scaling_factor from config, default to 8 for AutoencoderKL + vae_scale_factor = config.get("scaling_factor", 8) + else: + vae_scale_factor = 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func(images: torch.Tensor): + """Convert tensor images to PIL images.""" + return image_processor.postprocess(images) + + return post_process_func + + +class HyperCLOVAXVisionPipeline(nn.Module): + """ + HyperCLOVAX Vision Pipeline for vision token to image generation. + + This pipeline: + 1. Embeds vision tokens using VisionTokenEmbedder + 2. Runs flow matching diffusion with VisionTransformer + 3. Decodes latents to images using VAE + 4. Optionally applies autoguidance with transformer2 + + Args: + od_config: OmniDiffusionConfig containing model configuration + prefix: Prefix for weight loading (default: "") + """ + + @staticmethod + def get_dummy_extra() -> dict: + """Return dummy extra dict for warmup dummy run.""" + import numpy as np + # token_length=729, vocab_size=65536 per token_embedder/config.json + return {"vision_tokens": np.zeros((1, 729), dtype=np.int64)} + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.device = get_local_device() + + model = od_config.model + local_files_only = os.path.exists(model) + + def _load_component_config(subfolder: str) -> dict: + if os.path.isdir(model): + cfg_path = os.path.join(model, subfolder, "config.json") + if os.path.exists(cfg_path): + with open(cfg_path) as f: + return json.load(f) + return {} + cfg = get_hf_file_to_dict(f"{subfolder}/config.json", model) + return cfg or {} + + def _build_transformer_kwargs(cfg: dict) -> dict: + axes_dim = cfg.get("axes_dim", [8, 36, 36]) + return { + "in_channels": cfg.get("in_channels", 16), + "vec_in_dim": cfg.get("vec_in_dim", 1536), + "context_in_dim": cfg.get("context_in_dim", 1536), + "hidden_size": cfg.get("hidden_size", 1920), + "mlp_ratio": cfg.get("mlp_ratio", 4.0), + "num_heads": cfg.get("num_heads", 24), + "depth_single_blocks": cfg.get("depth_single_blocks", 35), + "axes_dim": tuple(axes_dim), + "theta": cfg.get("theta", 10000), + "use_patchify": cfg.get("use_patchify", False), + } + + transformer_cfg = _load_component_config("transformer") + + # 1. Load scheduler + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + + # 2. Load VAE + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + + # 3. Initialize token embedder + self.token_embedder = VisionTokenEmbedder( + vocab_size=65536, + embedding_dim=1536, + token_length=729, + ) + + # 4. Initialize transformer + self.transformer = HyperCLOVAXVisionTransformer2DModel( + od_config=od_config, + **_build_transformer_kwargs(transformer_cfg), + ) + + # 5. Initialize transformer2 for autoguidance (if available) + transformer2_exists = False + if os.path.isdir(model): + # Local path: check filesystem + transformer2_path = os.path.join(model, "transformer2") + transformer2_exists = os.path.exists(transformer2_path) + else: + # Remote HF repo: check if transformer2 subfolder exists + try: + from huggingface_hub import HfFileSystem + + fs = HfFileSystem() + transformer2_exists = fs.exists(f"{model}/transformer2") + except Exception: + transformer2_exists = False + + if transformer2_exists: + transformer2_cfg = _load_component_config("transformer2") + if not transformer2_cfg: + transformer2_cfg = transformer_cfg + self.transformer2 = HyperCLOVAXVisionTransformer2DModel( + od_config=od_config, + **_build_transformer_kwargs(transformer2_cfg), + ) + else: + self.transformer2 = None + + # Weight sources for vLLM loader + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="token_embedder", + revision=None, + prefix="token_embedder.", + fall_back_to_pt=True, + ), + ] + + # Add transformer2 weights if available + if self.transformer2 is not None: + self.weights_sources.append( + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer2", + revision=None, + prefix="transformer2.", + fall_back_to_pt=True, + ) + ) + + # VAE configuration + self.vae_scale_factor = 8 + self.vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0) + self.vae_shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + + # Apply USP parallelization if configured + if od_config.parallel_config.sequence_parallel_size > 1: + try: + from .transformer_usp import parallelize_transformer + + self.transformer = parallelize_transformer(self.transformer) + if self.transformer2 is not None: + self.transformer2 = parallelize_transformer(self.transformer2) + logger.info("USP parallelization applied successfully") + except ImportError: + logger.warning("xDiT not available, skipping USP parallelization") + + self.to(self.device) + + def _prepare_latents( + self, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + """Prepare random latents for diffusion.""" + dtype = dtype or self.od_config.dtype + + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + latent_channels = 16 # VAE has 16 latent channels + + shape = (batch_size, latent_channels, latent_h, latent_w) + latents = torch.randn(shape, device=self.device, dtype=dtype, generator=generator) + + return latents + + def _prepare_img_ids( + self, + batch_size: int, + img_h: int, + img_w: int, + ) -> torch.Tensor: + """Prepare position IDs for the transformer.""" + img_ids = torch.zeros(img_h, img_w, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(img_h)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(img_w)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids.to(device=self.device, dtype=self.od_config.dtype) + + def _prepare_vision_spatial( + self, + vision_hidden: torch.Tensor, + img_h: int, + img_w: int, + ) -> torch.Tensor: + """ + Prepare vision spatial features for concatenation with latents. + + Interpolates vision hidden states to match latent spatial dimensions. + """ + # vision_hidden: (B, L, C) where L is typically 729 (27x27) + cond_h = cond_w = int(vision_hidden.shape[1] ** 0.5) + + # Reshape to spatial format + vision_spatial = rearrange(vision_hidden, "b (h w) c -> b c h w", h=cond_h, w=cond_w) + + # Interpolate to match latent size + vision_spatial = F.interpolate(vision_spatial, size=(img_h, img_w), mode="bilinear", align_corners=False) + + # Reshape back to sequence format + vision_spatial = rearrange(vision_spatial, "b c h w -> b (h w) c") + + return vision_spatial + + def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + """Decode latents to images using VAE.""" + latents = latents / self.vae_scaling_factor + self.vae_shift_factor + images = self.vae.decode(latents).sample + return images + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + """ + Generate images from vision tokens. + + Args: + req: OmniDiffusionRequest containing: + - extra["vision_tokens"]: Vision token IDs (B, L) or (L,) + - height: Output image height (default: 768) + - width: Output image width (default: 768) + - num_inference_steps: Number of diffusion steps (default: 50) + - guidance_scale: Autoguidance scale (default: 0.0) + - seed: Random seed (optional) + + Returns: + DiffusionOutput with generated images + """ + # Extract vision tokens from request + vision_tokens = req.extra.get("vision_tokens") + if vision_tokens is None: + return DiffusionOutput(output=None, error="vision_tokens required in req.extra") + + # Convert to tensor if needed + if isinstance(vision_tokens, list): + vision_tokens = torch.tensor(vision_tokens, dtype=torch.long) + elif isinstance(vision_tokens, np.ndarray): + vision_tokens = torch.from_numpy(vision_tokens).long() + + if vision_tokens.ndim == 1: + vision_tokens = vision_tokens.unsqueeze(0) + + vision_tokens = vision_tokens.to(self.device) + batch_size = vision_tokens.shape[0] + + # Get parameters from request sampling_params + sp = req.sampling_params + height = (sp.height if sp.height else 768) + width = (sp.width if sp.width else 768) + num_steps = (sp.num_inference_steps if sp.num_inference_steps else 50) + guidance_scale = (sp.guidance_scale if sp.guidance_scale else 0.0) + + # Setup generator for reproducibility + generator = sp.generator + if generator is None and sp.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(sp.seed) + + dtype = self.od_config.dtype + + # 1. Vision Token Embedding + vision_cond = self.token_embedder(vision_tokens) + vision_hidden = vision_cond["vision_last_hidden_state"].to(dtype) + vision_pooler = vision_cond["vision_pooler_output"].to(dtype) + + # 2. Prepare latents + latents = self._prepare_latents(batch_size, height, width, dtype=dtype, generator=generator) + + # 3. Prepare position IDs + img_h = height // self.vae_scale_factor + img_w = width // self.vae_scale_factor + img_ids = self._prepare_img_ids(batch_size, img_h, img_w) + + # 4. Prepare vision spatial features + vision_spatial = self._prepare_vision_spatial(vision_hidden, img_h, img_w) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_steps, device=self.device) + timesteps = self.scheduler.timesteps + + # Determine if using autoguidance + use_autoguidance = self.transformer2 is not None and guidance_scale > 0 + + # 6. Denoising loop + for i, t in enumerate(timesteps): + # Prepare input: concatenate latents with vision spatial + if self.transformer.use_patchify: + x_t = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + else: + x_t = rearrange(latents, "b c h w -> b (h w) c") + + x_t = torch.cat((x_t, vision_spatial), dim=2) + + # Convert timestep to sigma + t_batch = torch.full((batch_size,), t.item(), device=self.device, dtype=torch.long) + sigma = t_batch.float() / self.scheduler.config.num_train_timesteps + + # Forward pass + pred = self.transformer( + img=x_t, + img_ids=img_ids, + timesteps=sigma, + y=vision_pooler, + ) + + # Apply autoguidance + if use_autoguidance: + pred2 = self.transformer2( + img=x_t, + img_ids=img_ids, + timesteps=sigma, + y=vision_pooler, + ) + pred = pred + guidance_scale * (pred - pred2) + + # Unpatchify prediction + if self.transformer.use_patchify: + pred = rearrange( + pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=img_h // 2, + w=img_w // 2, + ph=2, + pw=2, + ) + else: + pred = rearrange(pred, "b (h w) c -> b c h w", h=img_h, w=img_w) + + # Scheduler step + latents = self.scheduler.step(pred, t, latents, generator=generator).prev_sample + + # 7. Decode latents + images = self._decode_latents(latents) + + return DiffusionOutput(output=images) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load model weights using AutoWeightsLoader.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/hyperclovax_vision/transformer_usp.py b/vllm_omni/diffusion/models/hyperclovax_vision/transformer_usp.py new file mode 100644 index 00000000000..5136fb12e55 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_vision/transformer_usp.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NAVER Cloud Corp. vision-decoder-api + +""" +VisionTransformer USP Wrapper for xDiT Integration. + +This module provides Unified Sequence Parallelism (USP) support for the +HyperCLOVAX VisionTransformer used in vision token to image generation. + +USP enables multi-GPU acceleration by: +- Splitting input sequences across GPUs (Ulysses parallelism) +- Using ring attention patterns for long sequences +- Efficiently gathering outputs after processing +""" + +import functools + +import torch +import torch.nn as nn +from einops import rearrange + +from .layers import timestep_embedding + +# xDiT imports +try: + from xfuser.core.distributed import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, + ) + from xfuser.model_executor.layers.usp import USP + + XDIT_AVAILABLE = True +except ImportError: + XDIT_AVAILABLE = False + + +def split_sequence(tensor: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Split tensor along sequence dimension for parallel processing. + + Args: + tensor: Input tensor to split + dim: Dimension to split along (default: 1 for sequence dim) + + Returns: + Local chunk of the tensor for this rank + """ + if not XDIT_AVAILABLE or get_sequence_parallel_world_size() <= 1: + return tensor + + world_size = get_sequence_parallel_world_size() + rank = get_sequence_parallel_rank() + + chunks = torch.chunk(tensor, world_size, dim=dim) + return chunks[rank].contiguous() + + +def gather_sequence(tensor: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Gather tensor from all ranks along sequence dimension. + + Args: + tensor: Local tensor chunk + dim: Dimension to gather along (default: 1 for sequence dim) + + Returns: + Full tensor gathered from all ranks + """ + if not XDIT_AVAILABLE or get_sequence_parallel_world_size() <= 1: + return tensor + + return get_sp_group().all_gather(tensor.contiguous(), dim=dim) + + +def split_rope_embedding(pe: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + Split RoPE position embedding for sequence parallelism. + + The VisionTransformer uses 3D position encoding with axes_dim [8, 36, 36]. + The PE tensor has shape (B, 1, L, head_dim//2, 2, 2) after EmbedAND. + + Args: + pe: Position embedding tensor + seq_len: Original sequence length + + Returns: + Local chunk of position embeddings + """ + if not XDIT_AVAILABLE or get_sequence_parallel_world_size() <= 1: + return pe + + world_size = get_sequence_parallel_world_size() + rank = get_sequence_parallel_rank() + + # PE shape: (B, 1, L, D, 2, 2) where L is sequence length + # Split along dim 2 (sequence dimension) + seq_dim = 2 + chunks = torch.chunk(pe, world_size, dim=seq_dim) + return chunks[rank].contiguous() + + +def apply_rope_usp( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding with USP support. + + Args: + xq: Query tensor (B, H, L, D) + xk: Key tensor (B, H, L, D) + freqs_cis: RoPE frequencies (B, 1, L, D//2, 2, 2) + + Returns: + Tuple of rotated query and key tensors + """ + # Reshape for RoPE application + # xq: (B, H, L, D) -> (B, H, L, D//2, 1, 2) + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + + # freqs_cis: (B, 1, L, D//2, 2, 2) - contains cos and sin + # Apply rotation: x_out = x * cos + rotate(x) * sin + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + + return ( + xq_out.reshape(*xq.shape).type_as(xq), + xk_out.reshape(*xk.shape).type_as(xk), + ) + + +def parallelize_transformer(transformer: nn.Module) -> nn.Module: + """ + Parallelize VisionTransformer for sequence parallelism. + + This function wraps the transformer's forward method to: + 1. Split input sequences across GPUs + 2. Replace attention with USP attention + 3. Gather outputs after processing + + Args: + transformer: HyperCLOVAXVisionTransformer2DModel instance + + Returns: + Modified transformer with USP support + """ + if not XDIT_AVAILABLE: + return transformer + + original_forward = transformer.forward + + @functools.wraps(transformer.__class__.forward) + def usp_forward( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + timesteps: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + """ + USP-enabled forward pass. + + Args: + img: Input tensor (B, L, in_channels + context_in_dim) + img_ids: Position IDs tensor (B, L, 3) + timesteps: Sigma/timestep tensor (B,) + y: Vision pooler output tensor (B, vec_in_dim) + + Returns: + Output tensor (B, L, out_channels) + """ + sp_world_size = get_sequence_parallel_world_size() + + if sp_world_size <= 1: + # Single GPU mode + return original_forward(img, img_ids, timesteps, y) + + # Split input sequences across GPUs + img_local = split_sequence(img, dim=1) + img_ids_local = split_sequence(img_ids, dim=1) + + # Run transformer with local sequences + output_local = _usp_transformer_forward(self, img_local, img_ids_local, timesteps, y) + + # Gather output from all ranks + output = gather_sequence(output_local, dim=1) + + return output + + # Bind the new forward method + usp_forward = usp_forward.__get__(transformer) + transformer.forward = usp_forward + + # Parallelize attention in single blocks + _parallelize_attention_blocks(transformer) + + return transformer + + +def _usp_transformer_forward( + transformer: nn.Module, + img: torch.Tensor, + img_ids: torch.Tensor, + timesteps: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """ + Internal forward pass with sequence-parallel attention. + + This function reimplements the transformer forward to use USP attention + instead of standard attention. + """ + if img.ndim != 3: + raise ValueError("Input img tensor must have 3 dimensions.") + + # Project input + img = transformer.img_in(img) + + # Time and vector embedding (no splitting needed - these are per-sample) + vec = transformer.time_in( + timestep_embedding(timesteps, 256).to(dtype=transformer.time_in.in_layer.weight.dtype, device=img.device) + ) + vec = vec + transformer.vector_in(y) + + # Position embedding - compute for local sequence + pe = transformer.pe_embedder(img_ids) + + # Single stream blocks with USP attention + for block in transformer.single_blocks: + img = _usp_single_block_forward(block, img, vec, pe) + + # Final projection + img = transformer.final_layer(img, vec) + + return img + + +def _usp_single_block_forward( + block: nn.Module, + x: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, +) -> torch.Tensor: + """ + Single block forward with USP attention. + + This replaces the standard attention with USP attention that + handles cross-GPU communication internally. + """ + mod, _ = block.modulation(vec) + x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift + qkv, mlp = torch.split( + block.linear1(x_mod), + [3 * block.hidden_size, block.mlp_hidden_dim], + dim=-1, + ) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads) + q, k = block.norm(q, k, v) + + # USP attention + if XDIT_AVAILABLE and get_sequence_parallel_world_size() > 1: + # Apply RoPE to local Q, K + q, k = apply_rope_usp(q, k, pe) + + # Use xfuser's USP for efficient parallel attention + # USP handles cross-GPU communication internally + attn = USP(q, k, v, dropout_p=0.0, is_causal=False) + + attn = rearrange(attn, "B H L D -> B L (H D)") + else: + # Standard attention with RoPE + from .layers import attention + + attn = attention(q, k, v, pe=pe) + + output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +def _parallelize_attention_blocks(transformer: nn.Module) -> None: + """ + Replace attention in all single blocks with USP-enabled attention. + + This modifies the blocks in-place to use USP attention. + """ + if not hasattr(transformer, "single_blocks"): + return + + for i, block in enumerate(transformer.single_blocks): + # Store original parameters + block._usp_enabled = True + block._original_forward = block.forward + + # Create new forward that uses USP + def make_usp_block_forward(blk): + @functools.wraps(blk.__class__.forward) + def usp_block_forward(self, x, vec, pe): + return _usp_single_block_forward(self, x, vec, pe) + + return usp_block_forward + + block.forward = make_usp_block_forward(block).__get__(block) + + +def create_parallel_transformer( + transformer: nn.Module, + ulysses_degree: int = 1, + ring_degree: int = 1, +) -> nn.Module: + """ + Create a parallelized transformer with specified parallelism degrees. + + This is a convenience function that parallelizes the transformer + for sequence parallelism. + + Args: + transformer: HyperCLOVAXVisionTransformer2DModel instance + ulysses_degree: Degree of Ulysses attention parallelism + ring_degree: Degree of Ring attention parallelism + + Returns: + Parallelized transformer + """ + if not XDIT_AVAILABLE: + return transformer + + total_degree = ulysses_degree * ring_degree + world_size = get_sequence_parallel_world_size() + + if world_size != total_degree: + raise ValueError( + f"World size ({world_size}) must equal ulysses_degree * ring_degree " + f"({ulysses_degree} * {ring_degree} = {total_degree})" + ) + + return parallelize_transformer(transformer) diff --git a/vllm_omni/diffusion/models/hyperclovax_vision/vision_token_embedder.py b/vllm_omni/diffusion/models/hyperclovax_vision/vision_token_embedder.py new file mode 100644 index 00000000000..90998670d3b --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_vision/vision_token_embedder.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NAVER Cloud Corp. vision-decoder-api + +""" +Vision Token Embedder for HyperCLOVAX Vision Decoder. + +Converts discrete vision tokens to continuous embeddings. +""" + +import numpy as np +import torch +import torch.nn as nn + + +class VisionTokenEmbedder(nn.Module): + """ + Vision Token Embedder that converts discrete vision tokens to embeddings. + + This module embeds vision tokens (discrete vocabulary indices) into + continuous vector representations for the VisionTransformer. + + Args: + vocab_size: Size of the vision token vocabulary (default: 65536) + embedding_dim: Dimension of the embedding vectors (default: 1536) + token_length: Expected number of tokens per image (default: 729 for 27x27) + """ + + def __init__( + self, + vocab_size: int = 65536, + embedding_dim: int = 1536, + token_length: int = 729, + ): + super().__init__() + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.token_length = token_length + + # Main vocabulary embeddings + self.vocab_embeddings = nn.Parameter(torch.zeros(vocab_size, embedding_dim)) + + # Unconditional embedding for classifier-free guidance + self.uncond_embedding = nn.Parameter(torch.zeros(1, embedding_dim)) + + def load_vocab_embeddings(self, embeddings: torch.Tensor) -> None: + """Load vocabulary embeddings from a tensor.""" + if embeddings.shape != (self.vocab_size, self.embedding_dim): + raise ValueError( + f"Expected embeddings shape ({self.vocab_size}, {self.embedding_dim}), got {embeddings.shape}" + ) + with torch.no_grad(): + self.vocab_embeddings.copy_(embeddings) + + def forward(self, tokens: torch.Tensor) -> dict[str, torch.Tensor]: + """ + Convert vision tokens to embeddings. + + Args: + tokens: Vision token IDs (B, L) where L is typically 729 + + Returns: + Dictionary with: + - vision_last_hidden_state: (B, L, embedding_dim) + - vision_pooler_output: (B, embedding_dim) - mean pooled + """ + # Look up embeddings + hidden_states = self.vocab_embeddings[tokens] + + # Mean pooling for pooler output + pooler_output = hidden_states.mean(dim=1) + + return { + "vision_last_hidden_state": hidden_states, + "vision_pooler_output": pooler_output, + } + + def get_uncond_embeddings(self, batch_size: int, token_length: int) -> dict[str, torch.Tensor]: + """ + Get unconditional embeddings for classifier-free guidance. + + Args: + batch_size: Batch size + token_length: Number of tokens per sample + + Returns: + Dictionary with unconditional hidden states and pooler output + """ + uncond_hidden = self.uncond_embedding.expand(batch_size, token_length, -1) + uncond_pooler = uncond_hidden.mean(dim=1) + + return { + "vision_last_hidden_state": uncond_hidden, + "vision_pooler_output": uncond_pooler, + } + + @classmethod + def from_numpy(cls, npy_path: str) -> "VisionTokenEmbedder": + """ + Create embedder from numpy file. + + Args: + npy_path: Path to .npy file containing embeddings + + Returns: + VisionTokenEmbedder instance with loaded embeddings + """ + embeddings = torch.from_numpy(np.load(npy_path)).float() + vocab_size, embedding_dim = embeddings.shape + + embedder = cls( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + token_length=729, + ) + embedder.load_vocab_embeddings(embeddings) + return embedder diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 6b6a9c7278c..d86e9b1f0ed 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeTextConfig -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_text_config from vllm.v1.engine.async_llm import AsyncEngineArgs @@ -103,21 +102,18 @@ def create_model_config(self) -> OmniModelConfig: return omni_config + @property + def output_modality(self) -> OutputModality: + """Parse engine_output_type into a type-safe OutputModality flag.""" + return OutputModality.from_string(self.engine_output_type) + @dataclass class AsyncOmniEngineArgs(AsyncEngineArgs): - """Async engine arguments for omni models, extending base AsyncEngineArgs. - Adds omni-specific configuration fields for multi-stage pipeline - processing and output type specification in async contexts. - Args: - stage_id: Identifier for the stage in a multi-stage pipeline (default: 0) - model_stage: Stage type identifier, e.g., "thinker" or "talker" - (default: "thinker") - model_arch: Model architecture name - (default: "Qwen2_5OmniForConditionalGeneration") - engine_output_type: Optional output type specification for the engine. - Used to route outputs to appropriate processors (e.g., "image", - "audio", "latents"). If None, output type is inferred. + """Async engine arguments for omni LLM stages. + + Extends AsyncEngineArgs with omni-specific multi-stage pipeline fields. + Used when launching LLM stages (stage_type=llm) within an async context. """ stage_id: int = 0 @@ -125,53 +121,22 @@ class AsyncOmniEngineArgs(AsyncEngineArgs): model_arch: str = "Qwen2_5OmniForConditionalGeneration" engine_output_type: str | None = None hf_config_name: str | None = None - - def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: - # transformers' get_text_config method is used to get the text config from thinker_config. - # to handle the case that each model stage has their own text config, - # we need to draw the text config from the corresponding model stage. - hf_config = config_dict["hf_config"] - hf_config_name = config_dict["hf_config_name"] - try: - # Try to get the stage-specific config (e.g., thinker_config, talker_config) - stage_config = getattr(hf_config, hf_config_name) - return stage_config.get_text_config() - except AttributeError: - # Fallback: if the attribute doesn't exist, use the default get_hf_text_config - logger.warning( - f"Config attribute '{hf_config_name}' not found in hf_config, " - "falling back to default get_hf_text_config" - ) - return get_hf_text_config(hf_config) - - def _ensure_omni_models_registered(self): - if hasattr(self, "_omni_models_registered"): - return True - register_omni_models_to_vllm() - self._omni_models_registered = True - return True - - def create_model_config(self) -> OmniModelConfig: - # register omni models to avoid model not found error - self._ensure_omni_models_registered() - # First, get the base ModelConfig from the parent class - base_config = super().create_model_config() - - # Create OmniModelConfig by copying all base config attributes - # and adding the new omni-specific fields - config_dict = base_config.__dict__.copy() - - # Add the new omni-specific fields - config_dict["stage_id"] = self.stage_id - config_dict["model_stage"] = self.model_stage - config_dict["model_arch"] = self.model_arch - config_dict["engine_output_type"] = self.engine_output_type - - config_dict["hf_config_name"] = self.hf_config_name - if self.hf_config_name is not None: - config_dict["hf_text_config"] = self.draw_hf_text_config(config_dict) - # Create and return the OmniModelConfig instance - omni_config = OmniModelConfig(**config_dict) - omni_config.hf_config.architectures = omni_config.architectures - - return omni_config + custom_process_next_stage_input_func: str | None = None + stage_connector_spec: dict[str, Any] = field(default_factory=dict) + async_chunk: bool = False + omni_kv_config: dict | None = None + quantization_config: Any | None = None + worker_type: str | None = None + + def __post_init__(self) -> None: + load_omni_general_plugins() + super().__post_init__() + + def create_engine_config(self, usage_context=None, **kwargs): + """Create engine config, injecting model_arch into hf_overrides if set.""" + if self.model_arch: + if self.hf_overrides is None: + self.hf_overrides = {} + if isinstance(self.hf_overrides, dict): + self.hf_overrides.setdefault("architectures", [self.model_arch]) + return super().create_engine_config(usage_context=usage_context, **kwargs) diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index 688387452c3..e0bfd1016dd 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -1,284 +1,32 @@ -import time -from collections.abc import Mapping -from typing import Any, cast +"""OmniInputProcessor: extends vLLM InputProcessor with OmniInputPreprocessor.""" -import torch from vllm.config import VllmConfig -from vllm.inputs import ProcessorInputs, PromptType -from vllm.inputs.parse import split_enc_dec_inputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict -from vllm.multimodal.utils import argsort_mm_positions -from vllm.platforms import current_platform -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams -from vllm.tokenizers import TokenizerLike -from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.renderers import BaseRenderer from vllm.v1.engine.input_processor import InputProcessor -from vllm_omni.engine import ( - AdditionalInformationEntry, - AdditionalInformationPayload, - OmniEngineCoreRequest, - PromptEmbedsPayload, -) from vllm_omni.inputs.preprocess import OmniInputPreprocessor -logger = init_logger(__name__) - class OmniInputProcessor(InputProcessor): - """Processor for omni models, handling multimodal inputs and embeddings. - - Extends the base vLLM Processor with support for processing prompt - embeddings and additional information payloads, enabling direct transfer - of pre-computed embeddings between pipeline stages. + """InputProcessor for omni models. - Args: - vllm_config: Global vLLM configuration - tokenizer: Tokenizer instance for text processing - mm_registry: Multi-modal registry for processing multimodal inputs + Extends the base vLLM InputProcessor by replacing the default + InputPreprocessor with OmniInputPreprocessor, which handles + omni-specific input types (prompt embeddings, additional information). """ - @staticmethod - def _dtype_to_name(dtype: torch.dtype) -> str: - """Convert torch dtype to string representation. - - Args: - dtype: PyTorch dtype to convert - - Returns: - String representation of the dtype (e.g., "float32", "int64") - """ - mapping = { - torch.float32: "float32", - torch.float: "float32", - torch.float16: "float16", - torch.half: "float16", - torch.bfloat16: "bfloat16", - torch.float64: "float64", - torch.double: "float64", - torch.int64: "int64", - torch.long: "int64", - torch.int32: "int32", - torch.int: "int32", - torch.int16: "int16", - torch.short: "int16", - torch.int8: "int8", - torch.uint8: "uint8", - torch.bool: "bool", - } - return mapping.get(dtype, str(dtype).replace("torch.", "")) - def __init__( self, vllm_config: VllmConfig, - tokenizer: TokenizerLike, + renderer: BaseRenderer | None = None, + *, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - super().__init__(vllm_config, tokenizer, mm_registry) + ) -> None: + super().__init__(vllm_config, renderer, mm_registry=mm_registry) + # Replace the base InputPreprocessor with OmniInputPreprocessor self.input_preprocessor = OmniInputPreprocessor( - self.model_config, - self.tokenizer, - mm_registry, - mm_processor_cache=self.mm_processor_cache, - ) - - def process_inputs( - self, - request_id: str, - prompt: PromptType, - params: SamplingParams | PoolingParams, - arrival_time: float | None = None, - lora_request: LoRARequest | None = None, - tokenization_kwargs: dict[str, Any] | None = None, - trace_headers: Mapping[str, str] | None = None, - priority: int = 0, - data_parallel_rank: int | None = None, - ) -> tuple[str | None, OmniEngineCoreRequest]: - """Process input prompt into an engine core request. - - Converts a prompt (text, tokens, or multimodal) into an - OmniEngineCoreRequest that can be processed by the engine. - Handles prompt embeddings and additional information payloads - for direct transfer between stages. - - Args: - request_id: Unique identifier for this request - prompt: Input prompt (text, token IDs, embeddings, or multimodal) - params: Sampling or pooling parameters for generation - arrival_time: Optional arrival timestamp (defaults to current time) - lora_request: Optional LoRA adapter request - tokenization_kwargs: Optional additional tokenization arguments - trace_headers: Optional tracing headers for observability - priority: Request priority (higher values processed first) - data_parallel_rank: Optional data parallel rank for distributed - inference - - Returns: - Tuple of (prompt_string, OmniEngineCoreRequest) where: - - prompt_string: The original prompt as a string, or None if - using embeddings - - OmniEngineCoreRequest: Processed request ready for the engine - - Raises: - ValueError: If data_parallel_rank is out of range or prompt_embeds - has incorrect shape - """ - self._validate_lora(lora_request) - self._validate_params(params) - - data_parallel_size = self.vllm_config.parallel_config.data_parallel_size - if data_parallel_rank is not None and not (0 <= data_parallel_rank < data_parallel_size): - raise ValueError(f"data_parallel_rank {data_parallel_rank} is out of range [0, {data_parallel_size}).") - - if arrival_time is None: - arrival_time = time.time() - - # Optionally generate multimodal hash overrides to avoid hashing - # multimodal data items by their content as their identifiers. - - # NOTE: when users explicitly turn off BOTH prefix caching and input - # processing caching, no multimodal features or embeddings will be - # reused across requests, therefore identifying multimodal data items - # by their content is no longer necessary, and we create uuids with - # request id-modality-index as multimodal hash overrides. - if ( - self.model_config.multimodal_config - and self.model_config.multimodal_config.mm_processor_cache_gb == 0 - and not self.cache_config.enable_prefix_caching - ): - mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) - else: - # Otherwise, use user-provided uuids as multimodal hash overrides - # if provided. - self._validate_multi_modal_uuids(prompt) - if isinstance(prompt, dict): - mm_uuids = cast(MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")) - else: - mm_uuids = None - - # Process inputs, which includes: - # 1. Tokenize text prompt, with LoRA request if one exists. - # 2. For multimodal models with a merged preprocessor, preprocess - # multimodal data and expand prompt token ids accordingly. - processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( - prompt, - tokenization_kwargs=tokenization_kwargs, - mm_uuids=mm_uuids, - ) - - current_platform.validate_request( - prompt=prompt, - params=params, - processed_inputs=processed_inputs, - ) - - eos_token_id = self.input_preprocessor.get_eos_token_id() - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - self._validate_model_inputs(encoder_inputs, decoder_inputs) - - # Mypy does not always properly infer the types of some elements of - # discriminated unions of TypedDicts, because of how it handles - # inheritance of TypedDict. If we explicitly extract the items we want - # we can avoid type errors from using `dict.get` later in the method. - _prompt_str: str | None = None if decoder_inputs["type"] == "embeds" else decoder_inputs.get("prompt") - prompt_token_ids = decoder_inputs["prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None - prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs["type"] == "embeds" else None - - sampling_params = None - pooling_params = None - if isinstance(params, SamplingParams): - # TODO: can we avoid cloning here in multiproc case? - sampling_params = params.clone() - # If unset max tokens, then generate up to the max_model_len. - if sampling_params.max_tokens is None: - seq_len = length_from_prompt_token_ids_or_embeds(prompt_token_ids, prompt_embeds) - sampling_params.max_tokens = self.model_config.max_model_len - seq_len - sampling_params.update_from_generation_config(self.generation_config_fields, eos_token_id) - if self.tokenizer is not None: - sampling_params.update_from_tokenizer(self.tokenizer) - else: - pooling_params = params.clone() - - # Multimodal related. - mm_features: list[MultiModalFeatureSpec] | None = None - - if decoder_inputs["type"] == "multimodal": - decoder_mm_inputs = decoder_inputs["mm_kwargs"] - decoder_mm_positions = decoder_inputs["mm_placeholders"] - decoder_mm_hashes = decoder_inputs["mm_hashes"] - - # Merge and flatten multimodal placeholders, hashes and inputs - # from dictionaries to lists, and sort them by each item's position - # in the input sequence. - sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - - mm_features = [] - for modality, idx in sorted_mm_idxs: - mm_features.append( - MultiModalFeatureSpec( - data=decoder_mm_inputs[modality][idx], - modality=modality, - identifier=decoder_mm_hashes[modality][idx], - mm_position=decoder_mm_positions[modality][idx], - ) - ) - - # Serialize prompt_embeds and additional_information if provided - # (direct-transfer path) - prompt_embeds_payload: PromptEmbedsPayload | None = None - additional_information_payload: AdditionalInformationPayload | None = None - if "prompt_embeds" in decoder_inputs: # type: ignore[operator] - pe: torch.Tensor = decoder_inputs["prompt_embeds"] # type: ignore[index] - if pe.ndim != 2: - raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size)") - # Move to CPU and ensure contiguous memory for stable serialization - pe_cpu = pe.detach().to("cpu").contiguous() - seq_len, hidden_size = pe_cpu.shape - dtype_str = self._dtype_to_name(pe_cpu.dtype) - data_bytes = pe_cpu.numpy().tobytes() - prompt_embeds_payload = PromptEmbedsPayload( - data=data_bytes, - shape=[int(seq_len), int(hidden_size)], - dtype=dtype_str, - ) - if "additional_information" in decoder_inputs: # type: ignore[operator] - raw_info: dict[str, Any] = decoder_inputs["additional_information"] # type: ignore[index] # noqa: E501 - entries: dict[str, AdditionalInformationEntry] = {} - for key, value in raw_info.items(): - if isinstance(value, torch.Tensor): - v_cpu = value.detach().to("cpu").contiguous() - dtype_str = self._dtype_to_name(v_cpu.dtype) - data_bytes = v_cpu.numpy().tobytes() - entry = AdditionalInformationEntry( - tensor_data=data_bytes, - tensor_shape=[int(x) for x in list(v_cpu.shape)], - tensor_dtype=dtype_str, - ) - elif isinstance(value, list): - entry = AdditionalInformationEntry(list_data=value) - else: - raise ValueError("additional_information values must be Tensor or list") - entries[key] = entry - additional_information_payload = AdditionalInformationPayload(entries=entries) - - return OmniEngineCoreRequest( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - mm_features=mm_features, - sampling_params=sampling_params, - pooling_params=pooling_params, - eos_token_id=eos_token_id, - arrival_time=arrival_time, - lora_request=lora_request, - cache_salt=decoder_inputs.get("cache_salt"), - priority=priority, - data_parallel_rank=data_parallel_rank, - trace_headers=trace_headers, - prompt_embeds=prompt_embeds_payload, - additional_information=additional_information_payload, + vllm_config, + renderer=self.renderer, + mm_registry=mm_registry, ) diff --git a/vllm_omni/entrypoints/cli/main.py b/vllm_omni/entrypoints/cli/main.py index 6a65d9d6cde..b3ec90a6edd 100644 --- a/vllm_omni/entrypoints/cli/main.py +++ b/vllm_omni/entrypoints/cli/main.py @@ -43,6 +43,7 @@ def main(): for cmd in new_cmds: cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) cmds[cmd.name] = cmd + sys.argv = [a for a in sys.argv if a != "--omni"] args = parser.parse_args() if args.subparser in cmds: cmds[args.subparser].validate(args) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 030fe4ec5de..c2482ae5f15 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -3,47 +3,66 @@ import json import multiprocessing as mp import os +import threading import time import uuid import weakref from collections.abc import Callable, Generator, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import asdict -from pprint import pformat -from typing import Any +from typing import Any, Literal, overload +import huggingface_hub +import msgspec.msgpack +import zmq from omegaconf import OmegaConf from tqdm.auto import tqdm -from vllm.inputs import PromptType +from vllm import SamplingParams from vllm.logger import init_logger +from vllm.utils.network_utils import make_zmq_socket +from vllm.v1.utils import get_engine_client_zmq_addr -from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, ) from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector +from vllm_omni.distributed.omni_connectors.utils.initialization import ( + resolve_omni_kv_config_for_stage, +) from vllm_omni.distributed.ray_utils.utils import ( create_placement_group, get_ray_queue_class, try_close_ray, ) -from vllm_omni.entrypoints.log_utils import OrchestratorMetrics from vllm_omni.entrypoints.omni_stage import OmniStage from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load from vllm_omni.entrypoints.utils import ( get_final_stage_id_for_e2e, - load_stage_configs_from_model, - load_stage_configs_from_yaml, - resolve_model_config_path, + inject_omni_kv_config, + load_and_resolve_stage_configs, +) +from vllm_omni.entrypoints.zmq_utils import ZmqQueue +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams +from vllm_omni.metrics import OrchestratorAggregator, StageRequestStats +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, ) from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) -def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): +def _weak_close_cleanup( + stage_list, + stage_in_queues, + stage_out_queues, + ray_pg, + zmq_ctx=None, + handshake_stop: threading.Event | None = None, + zmq_handshake_socket: zmq.Socket | None = None, + handshake_thread: threading.Thread | None = None, +): """Weak reference cleanup function for OmniBase instances.""" if stage_list: for q in stage_in_queues: @@ -51,6 +70,13 @@ def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): q.put_nowait(SHUTDOWN_TASK) except Exception as e: logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() + for q in stage_out_queues: + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() for stage in stage_list: try: stage.stop_stage_worker() @@ -58,30 +84,59 @@ def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): logger.warning(f"Failed to stop stage worker: {e}") try_close_ray(ray_pg) + # Gracefully shutdown handshake server thread + if handshake_stop is not None: + handshake_stop.set() + if handshake_thread is not None: + handshake_thread.join(timeout=2.0) + if handshake_thread.is_alive(): + logger.warning("Handshake server thread did not terminate gracefully within timeout") + + # Close ZMQ resources after thread has exited + if zmq_handshake_socket is not None: + zmq_handshake_socket.close(0) + if zmq_ctx is not None: + zmq_ctx.term() + def _dummy_snapshot_download(model_id): return model_id def omni_snapshot_download(model_id) -> str: + # If it's already a local path, just return it + if os.path.exists(model_id): + return model_id # TODO: this is just a workaround for quickly use modelscope, we should support # modelscope in weight loading feature instead of using `snapshot_download` if os.environ.get("VLLM_USE_MODELSCOPE", False): from modelscope.hub.snapshot_download import snapshot_download return snapshot_download(model_id) - else: - return _dummy_snapshot_download(model_id) + # For other cases (Hugging Face), perform a real download to ensure all + # necessary files (including *.pt for audio/diffusion) are available locally + # before stage workers are spawned. This prevents initialization timeouts. + # Return the original model_id so that model_config.model preserves + # HuggingFace semantics (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice") + # instead of the resolved cache path. + try: + download_weights_from_hf_specific( + model_name_or_path=model_id, + cache_dir=None, + allow_patterns=["*"], + require_all=True, + ) + except huggingface_hub.errors.RepositoryNotFoundError: + logger.warning(f"Repository not found for '{model_id}'.") + return model_id class OmniBase: """Base class for serving Omni models. Args: - *args: Variable length argument list. - - args[0]: Model name or path to load. + model: Model name or path to load. **kwargs: Arbitrary keyword arguments. - - model: Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -98,23 +153,28 @@ class OmniBase: - Additional keyword arguments passed to stage engines. """ - def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: - model = args[0] if args else kwargs.get("model", "") - assert model != "", "Null model id detected, please specify a model id." + def __init__(self, model: str, **kwargs: Any) -> None: model = omni_snapshot_download(model) - if args: - args[0] = model - elif kwargs.get("model", "") != "": - kwargs["model"] = model + kwargs["model"] = model + self._model = model # store for use in fallback processors init # Stage management attributes self.stage_list: list[OmniStage] = [] - self._stage_in_queues: list[mp.Queue] = [] - self._stage_out_queues: list[mp.Queue] = [] + self._stage_in_queues: list[Any] = [] + self._stage_out_queues: list[Any] = [] self._stages_ready: set[int] = set() self._ray_pg = None self._queue_cls = None self._ctx = None + self._zmq_ctx: zmq.Context | None = None + self._zmq_master_address: str | None = None + self._zmq_master_port: int | None = None + self._zmq_handshake_socket: zmq.Socket | None = None + self._handshake_thread: threading.Thread | None = None + self._handshake_stop: threading.Event | None = None + self._handshake_endpoints: dict[int, tuple[str, str]] = {} + self._handshake_seen: set[int] = set() # Track which stage IDs have completed ZMQ handshake + self._single_stage_id: int | None = None # Optional: deploy only a specific stage ID # Initialize stages - each stage will create appropriate instance based on stage_type # Stage workers will automatically create OmniLLM or OmniDiffusion instances @@ -190,6 +250,54 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" return default_stage_cfg + def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]: + """Resolve stage configs and inject defaults shared by orchestrator/headless.""" + # TODO(wuhang): + # Remove kwargs as parameters in the future. + # Use dataclass directly for engine args. + + stage_configs_path = kwargs.get("stage_configs_path", None) + + # TTS-specific CLI overrides + self.tts_max_instructions_length: int | None = kwargs.get("tts_max_instructions_length", None) + + # Load stage configurations from YAML + config_path, stage_configs = load_and_resolve_stage_configs( + model, + stage_configs_path, + kwargs, + default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs), + ) + + # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. + for cfg in stage_configs: + try: + if getattr(cfg, "stage_type", None) != "diffusion": + continue + if not hasattr(cfg, "engine_args") or cfg.engine_args is None: + cfg.engine_args = OmegaConf.create({}) + if kwargs.get("lora_path") is not None: + if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None: + cfg.engine_args.lora_path = kwargs["lora_path"] + lora_scale = kwargs.get("lora_scale") + if lora_scale is None: + # Backwards compatibility for older callers. + lora_scale = kwargs.get("static_lora_scale") + if lora_scale is not None: + if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: + cfg.engine_args.lora_scale = lora_scale + quantization_config = kwargs.get("quantization_config") + if quantization_config is not None: + if ( + not hasattr(cfg.engine_args, "quantization_config") + or cfg.engine_args.quantization_config is None + ): + cfg.engine_args.quantization_config = quantization_config + except Exception as e: + logger.warning("Failed to inject LoRA config for stage: %s", e) + + return config_path, stage_configs + def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: """Initialize stage list management.""" stage_init_timeout = kwargs.get("stage_init_timeout", 20) @@ -198,24 +306,16 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: worker_backend = kwargs.get("worker_backend", "multi_process") ray_address = kwargs.get("ray_address", None) batch_timeout = kwargs.get("batch_timeout", 10) - stage_configs_path = kwargs.get("stage_configs_path", None) log_stats = kwargs.get("log_stats", False) + self._single_stage_id = kwargs.get("stage_id", None) + self._zmq_master_address = kwargs.get("omni_master_address", None) + if self._zmq_master_address is None: + self._zmq_master_address = "127.0.0.1" + logger.info("No omni_master_address provided, defaulting to localhost (127.0.0.1)") + self._zmq_master_port = kwargs.get("omni_master_port", None) - ### base engine args - tokenizer = kwargs.get("tokenizer", None) - - base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None - - # Load stage configurations from YAML - if stage_configs_path is None: - self.config_path = resolve_model_config_path(model) - self.stage_configs = load_stage_configs_from_model(model, base_engine_args=base_engine_args) - if not self.stage_configs: - default_stage_cfg = self._create_default_diffusion_stage_cfg(kwargs) - self.stage_configs = OmegaConf.create(default_stage_cfg) - else: - self.config_path = stage_configs_path - self.stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args) + # Resolve stage configs shared by orchestrator/headless paths. + self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs) # Initialize connectors self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( @@ -223,11 +323,13 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: ) # Initialize stats paths - self._enable_stats: bool = bool(log_stats) + self.log_stats: bool = bool(log_stats) self.worker_backend = worker_backend self.ray_address = ray_address self.batch_timeout = batch_timeout + # async chunk remains the same for each stage + self.async_chunk = self._is_async_chunk_enable(self.stage_configs) # Build OmniStage instances in parallel, preserve original order def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: @@ -243,7 +345,7 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: self.stage_list = [st for _, st in results] self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] self.output_modalities = [st.final_output_type for st in self.stage_list] - logger.debug(f"[{self._name}] Loaded {len(self.stage_list)} stages") + logger.info(f"[{self._name}] Loaded {len(self.stage_list)} stages") if self.worker_backend == "ray": self._queue_cls = get_ray_queue_class() @@ -257,6 +359,11 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: # Wait for all stages to report readiness before seeding self._wait_for_stages_ready(timeout=init_timeout) + def _is_async_chunk_enable(self, stage_args: list) -> bool: + """get async chunk flag""" + engine_args = getattr(stage_args[0], "engine_args", None) + return bool(getattr(engine_args, "async_chunk", False)) + def _start_stages(self, model: str) -> None: """Start all stage processes.""" if self.worker_backend == "ray": @@ -264,10 +371,38 @@ def _start_stages(self, model: str) -> None: self._ray_pg = create_placement_group( number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" ) + else: + # Initialize ZMQ context + if self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + + # Allocate endpoints for each stage + total_stages = len(self.stage_configs) + self._handshake_endpoints = {} + + # If --stage-id is not set, use local_only mode + local_only = self._single_stage_id is None + + for sid in range(total_stages): + in_endpoint = get_engine_client_zmq_addr(local_only=local_only, host=self._zmq_master_address) + out_endpoint = get_engine_client_zmq_addr(local_only=local_only, host=self._zmq_master_address) + self._handshake_endpoints[sid] = (in_endpoint, out_endpoint) + logger.debug( + f"[{self._name}] Allocated endpoints for stage-{sid}: in={in_endpoint}, out={out_endpoint}" + ) + + # Start handshake server + self.start_handshake_server() for stage_id, stage in enumerate[OmniStage](self.stage_list): - in_q = self._queue_cls() - out_q = self._queue_cls() + if self.worker_backend == "ray": + in_q = self._queue_cls() + out_q = self._queue_cls() + else: + in_endpoint, out_endpoint = self._handshake_endpoints[stage_id] + in_q = ZmqQueue(self._zmq_ctx, zmq.PUSH, bind=in_endpoint) + out_q = ZmqQueue(self._zmq_ctx, zmq.PULL, bind=out_endpoint) + self._stage_in_queues.append(in_q) self._stage_out_queues.append(out_q) stage.attach_queues(in_q, out_q) @@ -277,6 +412,24 @@ def _start_stages(self, model: str) -> None: stage_id, ) + # Inject YAML-resolved connector config into omni_kv_config for + # in-engine usage (GPU model runner reads model_config.omni_kv_config). + try: + omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage( + self.omni_transfer_config, stage_id + ) + if omni_conn_cfg: + inject_omni_kv_config(stage, omni_conn_cfg, omni_from, omni_to) # type: ignore + + except Exception as e: + logger.debug("[Omni] Failed to inject omni connector config into stage-%s: %s", stage_id, e) + + if self._single_stage_id is not None and stage_id != int(self._single_stage_id): + logger.info( + f"[{self._name}] Skipping initialization of stage-{stage_id} worker due to single_stage_id setting" + ) + continue + stage.init_stage_worker( model, is_async=self.is_async, @@ -286,6 +439,7 @@ def _start_stages(self, model: str) -> None: connectors_config=stage_connectors_config, worker_backend=self.worker_backend, ray_placement_group=self._ray_pg, + ignore_runtime_config=True if self._single_stage_id is not None else False, ) logger.debug(f"[{self._name}] Stage-{stage_id} process started") @@ -296,6 +450,9 @@ def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str def _wait_for_stages_ready(self, timeout: int = 120) -> None: """Wait for all stages to report readiness with optimized polling.""" + if self._single_stage_id is not None and self.worker_backend != "ray": + timeout = self._wait_for_handshakes(timeout) + num_stages = len(self.stage_list) deadline = time.time() + max(0, int(timeout)) @@ -329,6 +486,7 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: ) suggestions = [ + f"Ignore this warning if the model weight download / load from disk time is longer than {timeout}s.", "Verify GPU/device assignment in config (runtime.devices) is correct.", "Check GPU/host memory availability; reduce model or batch size if needed.", "Check model weights path and network reachability (if loading remotely).", @@ -337,7 +495,23 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: formatted_suggestions = "\n".join(f" {i + 1}) {msg}" for i, msg in enumerate(suggestions)) - logger.error(f"[{self._name}] Stage initialization failed. Troubleshooting Steps:\n{formatted_suggestions}") + logger.warning(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + + def _is_profiler_enabled(self, stage_id: int) -> bool: + """Check if profiler config is set for a given stage.""" + stage = self.stage_list[stage_id] + # For diffusion stages, profiling is controlled by VLLM_TORCH_PROFILER_DIR env var + if stage.stage_type == "diffusion": + return True + # For LLM stages, check if profiler_config is set in engine_args + engine_args = getattr(stage.stage_config, "engine_args", None) + if engine_args is None: + return False + profiler_config = getattr(engine_args, "profiler_config", None) + if profiler_config is None: + return False + profiler = getattr(profiler_config, "profiler", None) + return profiler is not None def start_profile(self, stages: list[int] | None = None) -> None: """Start profiling for specified stages. @@ -363,6 +537,13 @@ def start_profile(self, stages: list[int] | None = None) -> None: for stage_id in stages: if stage_id < len(self.stage_list): + if not self._is_profiler_enabled(stage_id): + logger.info( + "[%s] Skipping start_profile for stage-%s: profiler config not set", + self._name, + stage_id, + ) + continue try: self.stage_list[stage_id].submit({"type": OmniStageTaskType.PROFILER_START}) logger.info("[%s] Sent start_profile to stage-%s", self._name, stage_id) @@ -386,6 +567,13 @@ def stop_profile(self, stages: list[int] | None = None) -> dict: for stage_id in stages: if stage_id < len(self.stage_list): + if not self._is_profiler_enabled(stage_id): + logger.info( + "[%s] Skipping stop_profile for stage-%s: profiler config not set", + self._name, + stage_id, + ) + continue stage = self.stage_list[stage_id] # Check if the stage object has our new bridge method @@ -445,6 +633,129 @@ def close(self) -> None: if hasattr(self, "_weak_finalizer"): self._weak_finalizer() + def _process_handshake_message(self, msg: Any) -> dict[str, Any]: + """Process incoming handshake message and generate response. + + Args: + msg: Decoded message from client + + Returns: + Response dictionary with ok status and either endpoints or error + """ + if not isinstance(msg, dict) or msg.get("type") != "handshake": + return {"ok": False, "error": "invalid handshake payload"} + + try: + stage_id = int(msg.get("stage_id")) + except (TypeError, ValueError) as e: + return {"ok": False, "error": f"invalid stage_id: {e}"} + + endpoints = self._handshake_endpoints.get(stage_id) + if endpoints is None: + return {"ok": False, "error": f"unknown stage_id: {stage_id}"} + + # Mark stage as seen and prepare success response + self._handshake_seen.add(stage_id) + in_endpoint, out_endpoint = endpoints + + logger.info( + "[%s] Handshake received from stage-%s", + self._name, + stage_id, + ) + + return { + "ok": True, + "in_endpoint": in_endpoint, + "out_endpoint": out_endpoint, + } + + def _run_handshake_server_loop(self) -> None: + """Main loop for handshake server - polls for messages and responds.""" + poller = zmq.Poller() + poller.register(self._zmq_handshake_socket, zmq.POLLIN) + + try: + while not self._handshake_stop.is_set(): + events = poller.poll(1000) + has_message = any(sock == self._zmq_handshake_socket and event == zmq.POLLIN for sock, event in events) + if not has_message: + continue + + msg = msgspec.msgpack.decode(self._zmq_handshake_socket.recv()) + response = msgspec.msgpack.encode(self._process_handshake_message(msg)) + self._zmq_handshake_socket.send(response) + finally: + poller.unregister(self._zmq_handshake_socket) + + def start_handshake_server(self) -> None: + """Start the ZMQ handshake server. + + The handshake server allows distributed stages to discover their + queue endpoints by querying the orchestrator with their stage_id. + Skips starting if the server is already running or ZMQ is not initialized. + """ + # Skip if already running or ZMQ not initialized + if self._handshake_thread is not None or self._zmq_ctx is None: + return + + # Skip if master address/port not configured + if not self._zmq_master_address or self._zmq_master_port is None: + return + + # Create server endpoint and socket + endpoint = get_engine_client_zmq_addr( + local_only=False, host=self._zmq_master_address, port=int(self._zmq_master_port) + ) + + self._handshake_stop = threading.Event() + self._zmq_handshake_socket = make_zmq_socket(self._zmq_ctx, endpoint, zmq.REP, bind=True, linger=5000) + + # Start server thread + self._handshake_thread = threading.Thread( + target=self._run_handshake_server_loop, daemon=True, name="zmq-handshake-server" + ) + self._handshake_thread.start() + + def _wait_for_handshakes(self, timeout: int = 120) -> int: + """Wait for handshakes from all expected stages. + + Args: + timeout: Timeout in seconds for waiting for handshakes. Default is 120s. + + Returns: + Remaining timeout in seconds after waiting for handshakes. + """ + total_stages = len(self.stage_configs) + expected = set(range(total_stages)) - {int(self._single_stage_id)} + if not expected: + return timeout + + deadline = time.time() + max(0, int(timeout)) + logger.info(f"[{self._name}] Waiting for handshakes from stages: {expected} (timeout: {timeout}s)") + + # NOTE: _handshake_seen may be updated from the handshake server thread. + # It is intentionally used here without additional locking because: + # - _handshake_seen only ever grows (stages are added but never removed), and + # - we only check membership and set inclusion relative to `expected`. + # Under these monotonic semantics and the CPython GIL, concurrent reads/writes + # are safe for this usage and cannot violate correctness: we may observe a + # slightly stale view, but the loop condition remains valid and eventually + # becomes true once all expected stages have handshaked or the timeout elapses. + while not expected.issubset(self._handshake_seen) and time.time() < deadline: + time.sleep(1.0) + + remaining_timeout = max(0, int(deadline - time.time())) + + if not expected.issubset(self._handshake_seen): + missing = sorted(expected - self._handshake_seen) + logger.warning( + f"[{self._name}] Handshake timeout: {len(self._handshake_seen)}/{len(expected)} " + f"stages completed handshake. Missing stages: {missing}" + ) + + return remaining_timeout + @property def _name(self) -> str: return "OmniBase" @@ -458,10 +769,8 @@ class Omni(OmniBase): """Unified entrypoint for both LLM and Diffusion models for better usability. Args: - *args: Variable length argument list. - - args[0]: Model name or path to load. + model: Model name or path to load. **kwargs: Arbitrary keyword arguments. - - model: Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -483,8 +792,8 @@ class Omni(OmniBase): >>> print(outputs) """ - def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: - super().__init__(*args, **kwargs) + def __init__(self, model: str, **kwargs: Any) -> None: + super().__init__(model, **kwargs) # Register weak reference cleanup (called on garbage collection) self._weak_finalizer = weakref.finalize( @@ -492,11 +801,39 @@ def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: _weak_close_cleanup, self.stage_list, self._stage_in_queues, + self._stage_out_queues, self._ray_pg, + self._zmq_ctx, + self._handshake_stop, + self._zmq_handshake_socket, + self._handshake_thread, ) + @overload def generate( - self, *args: Any, **kwargs: dict[str, Any] + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: Literal[True], + ) -> Generator[OmniRequestOutput, None, None]: ... + + @overload + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: Literal[False] = False, + ) -> list[OmniRequestOutput]: ... + + def generate( + self, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: OmniSamplingParams | Sequence[OmniSamplingParams] | None = None, + *, + py_generator: bool = False, + use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]: """Generate outputs for the given prompts. @@ -504,12 +841,10 @@ def generate( Each stage will use OmniLLM or OmniDiffusion based on stage_type. Args: - *args: Variable length argument list. - - args[0]: Input prompts for generation. - - args[1]: Optional list of per-stage parameters. - **kwargs: Arbitrary keyword arguments. - - prompt: Input prompts for generation (if not in args). - - sampling_params_list: Optional list of per-stage parameters (if not in args). + prompts: Input prompt(s) for generation. + sampling_params_list: Optional list of per-stage parameters. + py_generator: Whether the returned result(s) are wrapped in a generator instead of a list. + use_tqdm: Whether to use tqdm progress bar Returns: List of OmniRequestOutput objects, one for each input prompt. @@ -519,40 +854,26 @@ def generate( Raises: ValueError: If sampling_params_list is None or has incorrect length. """ - prompts = args[0] if args else kwargs.get("prompts") - sampling_params_list = args[1] if len(args) > 1 else kwargs.get("sampling_params_list") - py_generator = kwargs.get("py_generator", False) - if prompts is None: - if kwargs.get("prompt") is None: - raise ValueError("prompts is required for generation") - prompts = kwargs.get("prompt") - if sampling_params_list is None: - # For Omni LLM, the params are parsed via the yaml file. For the current version, - # diffusion params can parsed via the command line. - omni_params_kwargs = { - k: v for k, v in kwargs.items() if k not in ["prompt", "request_id", "output_modalities"] - } - - per_stage_params: list[Any] = [] - for stage_id, stage in enumerate(self.stage_list): - stage_type = getattr(stage, "stage_type", "llm") - if stage_type == "diffusion": - default_dict = self.default_sampling_params_list[stage_id] - # Merge user-provided kwargs - merged = {**default_dict, **omni_params_kwargs} - # Diffusion only needs to keep diff params, will be used via OmniDiffusionRequest - per_stage_params.append(merged) + sampling_params_list = self.default_sampling_params_list + elif not isinstance(sampling_params_list, Sequence): + # TODO: After the recent introduction of BAGEL model (one LLM and one Diffusion), + # expect the text_to_image example code to run when only passing one OmniDiffusionSamplingParams + # This behavior may be confusing, and future PR can improve it. + per_stage_params: list[OmniSamplingParams] = [] + for default_stage_sp in self.default_sampling_params_list: + default_sp_type = default_stage_sp.__class__ + if default_sp_type == sampling_params_list.__class__: + per_stage_params.append(sampling_params_list) else: - # LLM directly constructs SamplingParams, don't use the merged params - per_stage_params.append(self.default_sampling_params_list[stage_id]) - + per_stage_params.append(default_stage_sp) sampling_params_list = per_stage_params + try: if py_generator: return self._run_generation_with_generator(prompts, sampling_params_list) else: - outputs = list(self._run_generation(prompts, sampling_params_list)) + outputs = list(self._run_generation(prompts, sampling_params_list, use_tqdm)) return outputs except Exception as e: logger.exception("[Orchestrator] Failed to run generation: %s", e) @@ -562,8 +883,8 @@ def generate( def _run_generation_with_generator( self, - prompts: PromptType | Sequence[PromptType] | OmniDiffusionRequest | Sequence[OmniDiffusionRequest], - sampling_params_list: Any | Sequence[Any] | None, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: Sequence[OmniSamplingParams], ) -> Generator[OmniRequestOutput, None, None]: """Run generation through all stages in the pipeline and return a generator.""" gen = self._run_generation(prompts, sampling_params_list) @@ -578,8 +899,8 @@ def _run_generation_with_generator( def _run_generation( self, - prompts: PromptType | Sequence[PromptType] | OmniDiffusionRequest | Sequence[OmniDiffusionRequest], - sampling_params_list: Any | Sequence[Any] | None = None, + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params_list: Sequence[OmniSamplingParams], use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None]: """Run generation through all stages in the pipeline.""" @@ -587,18 +908,20 @@ def _run_generation( if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") - # Normalize sampling_params_list to a list - if not isinstance(sampling_params_list, (list, tuple)): - sampling_params_list = [sampling_params_list] - else: - sampling_params_list = list(sampling_params_list) - if len(sampling_params_list) != len(self.stage_list): raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + for i, (stage, sp) in enumerate(zip(self.stage_list, sampling_params_list)): + ExpectedSPType = OmniDiffusionSamplingParams if stage.stage_type == "diffusion" else SamplingParams + if not isinstance(sp, ExpectedSPType): + raise ValueError( + f"Expected sampling parameters with type {ExpectedSPType} in stage {i}, got {sp.__class__}" + ) + # Normalize prompts to a list for per-request iteration - if not isinstance(prompts, (list, tuple)): - request_prompts: list[PromptType] = [prompts] + # str is also Sequence but only test list-like containers here + if isinstance(prompts, str) or not isinstance(prompts, Sequence): + request_prompts: list[OmniPromptType] = [prompts] else: request_prompts = list(prompts) @@ -606,8 +929,8 @@ def _run_generation( num_stages = len(self.stage_list) # Generate globally unique request IDs and map them to original prompts - request_ids: list[str] = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] - request_id_to_prompt: dict[str, PromptType] = {rid: p for rid, p in zip(request_ids, request_prompts)} + request_ids = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] + request_id_to_prompt = {rid: p for rid, p in zip(request_ids, request_prompts)} # Track per-request start time for end-to-end timing _req_start_ts: dict[str, float] = {} @@ -626,10 +949,11 @@ def _run_generation( final_stage_id_to_prompt[rid] = final_stage_id_for_e2e # Metrics/aggregation helper - metrics = OrchestratorMetrics( + metrics = OrchestratorAggregator( num_stages, - self._enable_stats, + self.log_stats, _wall_start_ts, + final_stage_id_to_prompt, ) it = request_id_to_prompt.items() @@ -696,11 +1020,15 @@ def _run_generation( # Mark last output time for this stage whenever we receive outputs metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) try: - _m = result.get("metrics") + _m: StageRequestStats = result.get("metrics") if _m is not None: - if not isinstance(_m, dict): - _m = asdict(_m) - metrics.on_stage_metrics(stage_id, req_id, _m) + # Accumulate generation time + metrics.accumulated_gen_time_ms[req_id][stage_id] += _m.stage_gen_time_ms + + # For diffusion stages, we also accumulate diffusion time + metrics.accumulate_diffusion_metrics(stage.stage_type, req_id, engine_outputs) + + metrics.on_stage_metrics(stage_id, req_id, _m, stage.final_output_type) if pbar: elapsed = pbar.format_dict["elapsed"] or 1e-6 # Aggregate total tokens/images across all stages @@ -737,8 +1065,7 @@ def _run_generation( # End-to-end timing and time-per-token for final output # (only once per request at the designated final stage) try: - rid_key = str(req_id) - if stage_id == final_stage_id_to_prompt[req_id] and rid_key not in metrics.e2e_done: + if stage_id == final_stage_id_to_prompt[req_id]: metrics.on_finalize_request( stage_id, req_id, @@ -748,17 +1075,43 @@ def _run_generation( logger.exception( f"[{self._name}] Finalize request handling error for req {req_id} at stage {stage_id}: {e}", ) - yield OmniRequestOutput( + output_to_yield = OmniRequestOutput( stage_id=stage_id, final_output_type=stage.final_output_type, # type: ignore[attr-defined] request_output=engine_outputs, ) + # Record audio generated frames (only when finished) + try: + finished = ( + engine_outputs.finished + if hasattr(engine_outputs, "finished") + else ( + engine_outputs[0].finished + if isinstance(engine_outputs, list) + and engine_outputs + and hasattr(engine_outputs[0], "finished") + else False + ) + ) + if finished: + metrics.record_audio_generated_frames(output_to_yield, stage_id, req_id) + except Exception as e: + logger.exception( + f"[{self._name}] Failed to record audio metrics for req {req_id} at stage {stage_id}: {e}", + ) + + yield output_to_yield + next_stage_id = stage_id + 1 if next_stage_id <= final_stage_id_to_prompt[req_id]: next_stage: OmniStage = self.stage_list[next_stage_id] try: - next_inputs = next_stage.process_engine_inputs(self.stage_list, [request_id_to_prompt[req_id]]) + # Derive inputs for the next stage, record preprocess time + with metrics.stage_postprocess_timer(stage_id, req_id): + next_inputs = next_stage.process_engine_inputs( + self.stage_list, [request_id_to_prompt[req_id]] + ) except Exception as e: logger.exception( f"[{self._name}] Process engine inputs error for req {req_id}" @@ -812,8 +1165,7 @@ def _run_generation( # Summarize and print stats try: - summary = metrics.build_and_log_summary(final_stage_id_to_prompt) - logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + metrics.build_and_log_summary() except Exception as e: logger.exception(f"[{self._name}] Failed to build/log summary: {e}") diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index e27e4eff635..0163f8988f3 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -1,35 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging -from dataclasses import fields +import uuid +from collections.abc import Sequence -from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest - -# TODO configure logging properly -logging.basicConfig(level=logging.INFO) - -logger = init_logger(__name__) - - -def prepare_requests(prompt: str | list[str], **kwargs): - field_names = {f.name for f in fields(OmniDiffusionRequest)} - - init_kwargs = {"prompt": prompt} - - for key, value in kwargs.items(): - if key in field_names: - init_kwargs[key] = value - - if "guidance_scale" in kwargs: - init_kwargs["guidance_scale_provided"] = True - - return OmniDiffusionRequest(**init_kwargs) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.outputs import OmniRequestOutput class OmniDiffusion: @@ -44,79 +25,97 @@ class OmniDiffusion: """ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): + # Capture stage info from kwargs before they might be filtered out + stage_id = kwargs.get("stage_id") + engine_input_source = kwargs.get("engine_input_source") + if od_config is None: od_config = OmniDiffusionConfig.from_kwargs(**kwargs) elif isinstance(od_config, dict): + # If config is dict, check it too (priority to kwargs if both exist) + if stage_id is None: + stage_id = od_config.get("stage_id") + if engine_input_source is None: + engine_input_source = od_config.get("engine_input_source") od_config = OmniDiffusionConfig.from_kwargs(**od_config) self.od_config = od_config + # Inject stage info into omni_kv_config if present + if stage_id is not None: + self.od_config.omni_kv_config.setdefault("stage_id", stage_id) + if engine_input_source is not None: + self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) + + # Detect model class and load config # Diffusers-style models expose `model_index.json` with `_class_name`. - # Bagel models (and other non-diffusers) typically expose `config.json`. + # Non-diffusers models (e.g. Bagel, NextStep, GLM-Image) only have `config.json`, + # so we fall back to reading that and mapping model_type manually. try: config_dict = get_hf_file_to_dict( "model_index.json", od_config.model, ) - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() + if config_dict is not None: + od_config.model_class_name = config_dict.get("_class_name", None) + od_config.update_multimodal_support() - tf_config_dict = get_hf_file_to_dict( - "transformer/config.json", - od_config.model, - ) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - except (AttributeError, OSError, ValueError): + tf_config_dict = get_hf_file_to_dict( + "transformer/config.json", + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + else: + raise FileNotFoundError("model_index.json not found") + except (AttributeError, OSError, ValueError, FileNotFoundError): cfg = get_hf_file_to_dict("config.json", od_config.model) if cfg is None: raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") + # Map model_type or architecture to pipeline class model_type = cfg.get("model_type") architectures = cfg.get("architectures") or [] + pipeline_class = None + # Bagel/NextStep models don't have a model_index.json, so we set the pipeline class name manually if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: - od_config.model_class_name = "BagelPipeline" - od_config.tf_model_config = TransformerConfig() - od_config.update_multimodal_support() - else: - raise + pipeline_class = "BagelPipeline" + elif model_type == "nextstep": + if od_config.model_class_name is None: + pipeline_class = "NextStep11Pipeline" + elif model_type == "glm-image" or "GlmImageForConditionalGeneration" in architectures: + pipeline_class = "GlmImagePipeline" + elif architectures and len(architectures) == 1: + pipeline_class = architectures[0] + + if pipeline_class is None: + raise ValueError(f"Unknown model type: {model_type}, architectures: {architectures}") + + od_config.model_class_name = pipeline_class + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config) def generate( self, - prompt: str | list[str], - **kwargs, - ): - prompts = [] - if isinstance(prompt, str): - prompts.append(prompt) - elif isinstance(prompt, list): - prompts.extend(prompt) + prompts: OmniPromptType | Sequence[OmniPromptType], + sampling_params: OmniDiffusionSamplingParams, + request_ids: list[str] = [], + ) -> list[OmniRequestOutput]: + if isinstance(prompts, (str, dict)): + prompts = [prompts] else: - raise ValueError("Prompt must be a string or a list of strings") - - requests: list[OmniDiffusionRequest] = [] + prompts = list(prompts) # Check if request_id is provided in kwargs - request_id = kwargs.get("request_id") - - for i, p in enumerate(prompts): - req_kwargs = kwargs.copy() - if request_id is None: - # Generate default ID consistent with OmniLLM: "{i}_{uuid}" - req_kwargs["request_id"] = f"{i}" - - requests.append( - prepare_requests( - p, - **req_kwargs, - ) - ) - logger.info(f"Prepared {len(requests)} requests for generation.") - return self._run_engine(requests) + if len(request_ids) < len(prompts): + request_ids.extend(f"{i + len(request_ids)}_{uuid.uuid4()}" for i in range(len(prompts) - len(request_ids))) + + request = OmniDiffusionRequest(prompts, sampling_params, request_ids) + return self._run_engine(request) - def _run_engine(self, requests: list[OmniDiffusionRequest]): - return self.engine.step(requests) + def _run_engine(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: + return self.engine.step(request) def close(self) -> None: self.engine.close() diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 05a48feee0e..ec9248e3041 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -1,12 +1,15 @@ +from collections.abc import Callable from typing import Any import cloudpickle from pydantic import ValidationError +from tqdm import tqdm # External library imports (vLLM) from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field from vllm.entrypoints.llm import LLM from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.usage.usage_lib import UsageContext from vllm.utils.counter import Counter @@ -19,6 +22,7 @@ from vllm_omni.engine.input_processor import OmniInputProcessor from vllm_omni.engine.output_processor import MultimodalOutputProcessor from vllm_omni.entrypoints.utils import ( + filter_dataclass_kwargs, load_stage_configs_from_model, load_stage_configs_from_yaml, resolve_model_config_path, @@ -78,7 +82,7 @@ def __init__( self.worker_backend = kwargs.get("worker_backend", "multi_process") self.ray_address = kwargs.get("ray_address", None) self.batch_timeout = batch_timeout - self._enable_stats: bool = bool(log_stats) + self.log_stats: bool = bool(log_stats) # Load stage configurations if stage_configs_path is None: @@ -118,6 +122,9 @@ def __init__( ) raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e + # Extract omni_kv_config from kwargs if present (injected by Omni) + omni_kv_config = kwargs.pop("omni_kv_config", None) + if compilation_config is not None: if isinstance(compilation_config, int): compilation_config_instance = CompilationConfig(level=compilation_config) @@ -144,7 +151,8 @@ def __init__( model=model, compilation_config=compilation_config_instance, structured_outputs_config=structured_outputs_instance, - **kwargs, + omni_kv_config=omni_kv_config, + **filter_dataclass_kwargs(OmniEngineArgs, kwargs), ) # Create the Engine (autoselects V0 vs V1) @@ -154,9 +162,7 @@ def __init__( log_stats=self.llm_engine.log_stats, engine_core_output_type=engine_args.engine_output_type, ) - self.llm_engine.input_processor = OmniInputProcessor( - vllm_config=self.llm_engine.vllm_config, tokenizer=self.llm_engine.tokenizer - ) + self.llm_engine.input_processor = OmniInputProcessor(vllm_config=self.llm_engine.vllm_config) self.engine_class = type(self.llm_engine) self.request_counter = Counter() @@ -190,3 +196,47 @@ def __del__(self) -> None: # best-effort self.close() except Exception as e: logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True) + + def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[RequestOutput | PoolingRequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), + ) + + # Run the engine. + outputs: list[RequestOutput | PoolingRequestOutput] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + n = len(output.outputs) + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) * n + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum(len(stp.token_ids) for stp in output.outputs) + out_spd = total_out_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s" + pbar.update(n) + else: + pbar.update(1) + if pbar.n == num_requests: + pbar.refresh() + + if use_tqdm: + pbar.close() + # Sort the outputs by the int part of request ID which is in format of 'int-uuid'. + # This is necessary because some requests may be finished earlier than + # its previous requests. + return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index b98344aca1d..cf0fe709379 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -12,12 +12,21 @@ logger = logging.getLogger(__name__) -class OmniStageTaskType(enum.Enum): - GENERATE = "generate" - ABORT = "abort" - SHUTDOWN = "shutdown" - PROFILER_START = "profiler_start" - PROFILER_STOP = "profiler_stop" +def set_stage_devices( + stage_id: int, + devices: str | int | None, + device_type: str | None = None, +) -> str | None: + """Configure per-stage device visibility and current device (CUDA or NPU). + + This function sets environment variables that control which devices are visible + to the process. It must be called BEFORE worker initialization so that workers + see the correct devices. + + + NOTE: This will set the control variable for the appropriate platform. + - CUDA: CUDA_VISIBLE_DEVICES + - NPU: ASCEND_RT_VISIBLE_DEVICES SHUTDOWN_TASK = {"type": OmniStageTaskType.SHUTDOWN} @@ -288,3 +297,78 @@ def _to_dict(x: Any) -> dict[str, Any]: return dict(x) except Exception: return {} +import enum +import json +from multiprocessing import shared_memory as _shm + +import enum + +class OmniStageTaskType(enum.Enum): + GENERATE = "generate" + ABORT = "abort" + SHUTDOWN = "shutdown" + PROFILER_START = "profiler_start" + PROFILER_STOP = "profiler_stop" + + +SHUTDOWN_TASK = {"type": OmniStageTaskType.SHUTDOWN} + + +def is_profiler_task(task_type: OmniStageTaskType) -> bool: + return task_type in (OmniStageTaskType.PROFILER_START, OmniStageTaskType.PROFILER_STOP) + +def maybe_dump_to_shm(obj, threshold: int) -> tuple: + """Dump object to SHM if serialized size exceeds threshold.""" + payload = serialize_obj(obj) + if len(payload) > threshold: + return True, shm_write_bytes(payload, name=None) + return False, obj + +def _resolve_model_tokenizer_paths( + model: str, + engine_args: dict, +) -> str: + """Resolve model and tokenizer paths for non-standard directory structures. + + Some models (e.g., GLM-Image) have tokenizer in root and model in subdirectory. + This function handles model_subdir and tokenizer_subdir engine_args. + + Args: + model: Base model path + engine_args: Engine arguments (modified in-place to remove subdir args + and set tokenizer if needed) + + Returns: + Resolved model path (may be subdirectory of original) + """ + import os + + model_subdir = engine_args.pop("model_subdir", None) + tokenizer_subdir = engine_args.pop("tokenizer_subdir", None) + base_model_path = model + + if model_subdir: + model = os.path.join(model, model_subdir) + logger.info(f"Using model subdirectory: {model}") + + if tokenizer_subdir is not None: + tokenizer_path = os.path.join(base_model_path, tokenizer_subdir) if tokenizer_subdir else base_model_path + engine_args["tokenizer"] = tokenizer_path + logger.info(f"Using tokenizer from: {tokenizer_path}") + elif model_subdir and "tokenizer" not in engine_args: + engine_args["tokenizer"] = base_model_path + logger.info(f"Using tokenizer from base model path: {base_model_path}") + + return model + +def maybe_load_from_ipc(container: dict[str, Any], obj_key: str, shm_key: str) -> Any: + """Load object from container that may carry SHM or inline object. + + Deprecated: prefer `maybe_load_from_ipc_with_metrics` to also obtain + decode-time and size metrics. + """ + if shm_key in container: + from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer + + return OmniSerializer.deserialize(shm_read_bytes(container[shm_key])) + return container[obj_key] diff --git a/vllm_omni/entrypoints/zmq_utils.py b/vllm_omni/entrypoints/zmq_utils.py new file mode 100644 index 00000000000..2ef5685cdaa --- /dev/null +++ b/vllm_omni/entrypoints/zmq_utils.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""ZMQ-based queue utilities for Omni IPC.""" + +from __future__ import annotations + +import queue +from typing import Any + +import zmq +from vllm.utils.network_utils import make_zmq_socket + + +class ZmqQueue: + """Queue-like wrapper on a ZMQ socket.""" + + def __init__( + self, + ctx: zmq.Context, + socket_type: int, + *, + bind: str | None = None, + connect: str | None = None, + recv_timeout_ms: int | None = None, + send_timeout_ms: int | None = None, + ) -> None: + # Determine path and bind mode + path = bind if bind is not None else connect + if path is None: + raise ValueError("Either bind or connect must be specified") + bind_mode = bind is not None + + self._socket = make_zmq_socket(ctx, path, socket_type, bind=bind_mode, linger=5000) + + # Reusable poller for efficient polling operations + self._poller = zmq.Poller() + self._poller.register(self._socket, zmq.POLLIN) + + # Store default timeout settings + self._default_recv_timeout = recv_timeout_ms + self._default_send_timeout = send_timeout_ms + + # Apply timeout settings if specified + if recv_timeout_ms is not None: + self._socket.rcvtimeo = recv_timeout_ms + if send_timeout_ms is not None: + self._socket.sndtimeo = send_timeout_ms + + self.endpoint = path + + def put(self, obj: Any) -> None: + """Send an object to the queue. Blocks until sent or timeout.""" + try: + self._socket.send_pyobj(obj) + except zmq.Again as e: + raise queue.Full() from e + + def put_nowait(self, obj: Any) -> None: + """Send an object to the queue without blocking.""" + try: + self._socket.send_pyobj(obj, flags=zmq.NOBLOCK) + except zmq.Again as e: + raise queue.Full() from e + + def get(self, timeout: float | None = None) -> Any: + """Receive an object from the queue with optional timeout in seconds.""" + if timeout is None: + return self._socket.recv_pyobj() + + # Use the reusable poller for timeout handling + events = dict(self._poller.poll(int(timeout * 1000))) + if events.get(self._socket) == zmq.POLLIN: + return self._socket.recv_pyobj() + raise queue.Empty() + + def get_nowait(self) -> Any: + """Receive an object from the queue without blocking.""" + try: + return self._socket.recv_pyobj(flags=zmq.NOBLOCK) + except zmq.Again as e: + raise queue.Empty() from e + + def empty(self) -> bool: + """Check if the queue is empty without blocking.""" + events = dict(self._poller.poll(0)) + return events.get(self._socket) != zmq.POLLIN + + def close(self) -> None: + self._socket.close(0) + + +def create_zmq_queue(ctx: zmq.Context, endpoint: str, socket_type: int) -> ZmqQueue: + """Create a ZmqQueue from an endpoint string and socket type.""" + return ZmqQueue(ctx, socket_type, connect=endpoint) diff --git a/vllm_omni/model_executor/models/hcx_omni/__init__.py b/vllm_omni/model_executor/models/hcx_omni/__init__.py new file mode 100644 index 00000000000..50793fc6936 --- /dev/null +++ b/vllm_omni/model_executor/models/hcx_omni/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""HyperCLOVAX-SEED-Omni-8B models for vLLM-Omni.""" diff --git a/vllm_omni/model_executor/models/hcx_omni/hcx_omni.py b/vllm_omni/model_executor/models/hcx_omni/hcx_omni.py new file mode 100644 index 00000000000..9a1ab520499 --- /dev/null +++ b/vllm_omni/model_executor/models/hcx_omni/hcx_omni.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""HyperCLOVAX-SEED-Omni-8B multi-stage model dispatcher. + +Architecture overview +--------------------- +HyperCLOVAX-SEED-Omni-8B is a 3-stage omni model: + + Stage 0 – Thinker (this module, LLM engine) + Input : text + optional image/audio + Output: text tokens + discrete audio codes (128606–135167) + + discrete vision codes (135168+) + Config: engine_output_type = "latent" + + Stage 1 – Vision Decoder (diffusion engine) + Input : 729 discrete vision codes from stage 0 + Output: generated image (PNG / JPEG) + Config: model_class_name = "HyperCLOVAXVisionPipeline" + + Stage 2 – Audio Decoder (diffusion engine) + Input : N discrete audio codes from stage 0 + Output: 24 kHz waveform (WAV / PCM) + Config: model_class_name = "HyperCLOVAXAudioPipeline" + +Stages 1 and 2 are handled by the vLLM-Omni *diffusion* engine and do +**not** go through this LLM model registry. This dispatcher exists +only for stage 0 so that the standard ``model_arch`` routing works. +""" +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.hcx_omni import ( + HCXOmniDummyInputsBuilder, + HCXOmniForCausalLM, + HCXOmniMultiModalProcessor, + HCXOmniProcessingInfo, +) +from vllm.model_executor.models.interfaces import ( + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +@MULTIMODAL_REGISTRY.register_processor( + HCXOmniMultiModalProcessor, + info=HCXOmniProcessingInfo, + dummy_inputs=HCXOmniDummyInputsBuilder, +) +class HCXOmniForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsMRoPE, + SupportsPP, + SupportsQuant, +): + """Top-level HyperCLOVAX-SEED-Omni-8B model for vLLM-Omni. + + This class is the ``model_arch`` entry point for the thinker stage. + It delegates all logic to :class:`~vllm.model_executor.models.hcx_omni. + HCXOmniForCausalLM` from the vLLM base repository. + + The vision decoder and audio decoder stages use ``model_class_name`` + (diffusion engine) and therefore do not require an entry here. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + self._model = HCXOmniForCausalLM(vllm_config=vllm_config, prefix=prefix) + + # ------------------------------------------------------------------ # + # Delegate interface implementations to the inner model # + # ------------------------------------------------------------------ # + + @property + def config(self): + return self._model.config + + # SupportsMRoPE + def get_mrope_input_positions(self, *args: Any, **kwargs: Any): + return self._model.get_mrope_input_positions(*args, **kwargs) + + def iter_mm_grid_thw(self, *args: Any, **kwargs: Any): + return self._model.iter_mm_grid_thw(*args, **kwargs) + + # SupportsMultiModal + def get_multimodal_embeddings(self, **kwargs: Any): + return self._model.get_multimodal_embeddings(**kwargs) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + ) -> torch.Tensor: + return self._model.get_input_embeddings(input_ids, multimodal_embeddings) + + # SupportsPP + def make_empty_intermediate_tensors(self, *args: Any, **kwargs: Any): + return self._model.make_empty_intermediate_tensors(*args, **kwargs) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + return self._model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def compute_logits( + self, hidden_states: torch.Tensor + ) -> torch.Tensor | None: + return self._model.compute_logits(hidden_states) + + def load_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> set[str]: + return self._model.load_weights(weights) + + def get_mm_mapping(self): + return self._model.get_mm_mapping() diff --git a/vllm_omni/model_executor/models/hcx_omni/hcx_omni_thinker.py b/vllm_omni/model_executor/models/hcx_omni/hcx_omni_thinker.py new file mode 100644 index 00000000000..ebce44c4e4e --- /dev/null +++ b/vllm_omni/model_executor/models/hcx_omni/hcx_omni_thinker.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Thin wrapper around the vLLM base HCXOmniForCausalLM thinker. + +Registers the multimodal processor for the vLLM-Omni pipeline context +and exposes all interfaces required by the thinker stage. +""" +from collections.abc import Iterable, Mapping +from typing import Any + +import torch +from torch import nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.hcx_omni import ( + HCXOmniForCausalLM, + HCXOmniMultiModalProcessor, + HCXOmniProcessingInfo, + HCXOmniDummyInputsBuilder, +) +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.parse import MultiModalDataItems +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +@MULTIMODAL_REGISTRY.register_processor( + HCXOmniMultiModalProcessor, + info=HCXOmniProcessingInfo, + dummy_inputs=HCXOmniDummyInputsBuilder, +) +class HCXOmniThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsMRoPE, + SupportsPP, + SupportsQuant, +): + """Thinker stage model for HyperCLOVAX-SEED-Omni-8B. + + This is a thin wrapper around :class:`HCXOmniForCausalLM` (defined in + the vLLM base repository) that registers the multimodal processor and + exposes the standard vLLM model interfaces needed by the omni pipeline + thinker stage. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + self._model = HCXOmniForCausalLM(vllm_config=vllm_config, prefix=prefix) + + # --- delegate all interface methods to inner model ------------------- # + + @property + def config(self): + return self._model.config + + @property + def language_model(self): + return self._model.language_model + + @property + def visual(self): + return self._model.visual + + @property + def audio_tower(self): + return self._model.audio_tower + + # SupportsMRoPE + def get_mrope_input_positions(self, *args, **kwargs): + return self._model.get_mrope_input_positions(*args, **kwargs) + + def iter_mm_grid_thw(self, *args, **kwargs): + return self._model.iter_mm_grid_thw(*args, **kwargs) + + # SupportsMultiModal + def get_multimodal_embeddings( + self, **kwargs: Any + ) -> MultiModalEmbeddings | None: + return self._model.get_multimodal_embeddings(**kwargs) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + ) -> torch.Tensor: + return self._model.get_input_embeddings(input_ids, multimodal_embeddings) + + # SupportsPP + def make_empty_intermediate_tensors(self, *args, **kwargs): + return self._model.make_empty_intermediate_tensors(*args, **kwargs) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + return self._model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + return self._model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return self._model.load_weights(weights) + + def get_mm_mapping(self): + return self._model.get_mm_mapping() + + # SupportsQuant + def get_quant_config(self): + return getattr(self._model, "get_quant_config", lambda: None)() diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 56bceae41ab..c8a0fb6acaa 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -48,6 +48,118 @@ "qwen3_omni_code2wav", "Qwen3OmniMoeCode2Wav", ), + "CosyVoice3Model": ( + "cosyvoice3", + "cosyvoice3", + "CosyVoice3Model", + ), + "MammothModa2Qwen2ForCausalLM": ( + "mammoth_moda2", + "mammoth_moda2", + "MammothModa2Qwen2ForCausalLM", + ), + "MammothModa2ARForConditionalGeneration": ( + "mammoth_moda2", + "mammoth_moda2", + "MammothModa2ARForConditionalGeneration", + ), + "MammothModa2DiTPipeline": ( + "mammoth_moda2", + "pipeline_mammothmoda2_dit", + "MammothModa2DiTPipeline", + ), + "MammothModa2ForConditionalGeneration": ( + "mammoth_moda2", + "mammoth_moda2", + "MammothModa2ForConditionalGeneration", + ), + "Mammothmoda2Model": ( + "mammoth_moda2", + "mammoth_moda2", + "MammothModa2ForConditionalGeneration", + ), + "Qwen3TTSForConditionalGeneration": ( + "qwen3_tts", + "qwen3_tts_talker", + "Qwen3TTSTalkerForConditionalGeneration", + ), + "Qwen3TTSTalkerForConditionalGeneration": ( + "qwen3_tts", + "qwen3_tts_talker", + "Qwen3TTSTalkerForConditionalGeneration", + ), + "Qwen3TTSCode2Wav": ( + "qwen3_tts", + "qwen3_tts_code2wav", + "Qwen3TTSCode2Wav", + ), + ## mimo_audio + "MiMoAudioModel": ( + "mimo_audio", + "mimo_audio", + "MiMoAudioForConditionalGeneration", + ), + "MiMoAudioLLMModel": ( + "mimo_audio", + "mimo_audio_llm", + "MiMoAudioLLMForConditionalGeneration", + ), + "MiMoAudioToken2WavModel": ( + "mimo_audio", + "mimo_audio_code2wav", + "MiMoAudioToken2WavForConditionalGenerationVLLM", + ), + ## glm_image + "GlmImageForConditionalGeneration": ( + "glm_image", + "glm_image_ar", + "GlmImageForConditionalGeneration", + ), + "OmniBagelForConditionalGeneration": ( + "bagel", + "bagel", + "OmniBagelForConditionalGeneration", + ), + "HunyuanImage3ForCausalMM": ( + "hunyuan_image3", + "hunyuan_image3", + "HunyuanImage3ForConditionalGeneration", + ), + ## fish_speech (Fish Speech S2 Pro) + "FishSpeechSlowARForConditionalGeneration": ( + "fish_speech", + "fish_speech_slow_ar", + "FishSpeechSlowARForConditionalGeneration", + ), + "FishSpeechDACDecoder": ( + "fish_speech", + "fish_speech_dac_decoder", + "FishSpeechDACDecoder", + ), + ## Voxtral TTS + "VoxtralTTSForConditionalGeneration": ( + "voxtral_tts", + "voxtral_tts", + "VoxtralTTSForConditionalGeneration", + ), + "VoxtralTTSAudioGeneration": ( + "voxtral_tts", + "voxtral_tts_audio_generation", + "VoxtralTTSAudioGenerationForConditionalGeneration", + ), + "VoxtralTTSAudioTokenizer": ("voxtral_tts", "voxtral_tts_audio_tokenizer", "VoxtralTTSAudioTokenizer"), + ## HyperCLOVAX-SEED-Omni-8B + # stage 0 (thinker LLM) — stages 1/2 use DiffusionModelRegistry via model_class_name + "HCXVisionV2ForCausalLM": ( + "hcx_omni", + "hcx_omni", + "HCXOmniForConditionalGeneration", + ), + "HCXOmniForCausalLM": ( + "hcx_omni", + "hcx_omni", + "HCXOmniForConditionalGeneration", + ), } _VLLM_OMNI_MODELS = { diff --git a/vllm_omni/model_executor/stage_configs/hcx_omni.yaml b/vllm_omni/model_executor/stage_configs/hcx_omni.yaml new file mode 100644 index 00000000000..6a76fa2791d --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hcx_omni.yaml @@ -0,0 +1,102 @@ +runtime: + connectors: + shared_memory_connector: + extra: + shm_threshold_bytes: 65536 + name: SharedMemoryConnector + defaults: + max_inflight: 8 + window_size: -1 + edges: + - from: 0 + to: 1 + window_size: -1 + - from: 0 + to: 2 + window_size: -1 + - from: 1 + to: 2 + window_size: -1 + enabled: true +stage_args: +- default_sampling_params: + detokenize: true + max_tokens: 2048 + repetition_penalty: 1.0 + seed: 42 + temperature: 0.1 + top_k: -1 + top_p: 1.0 + engine_args: + enable_prefix_caching: false + enforce_eager: true + engine_output_type: latent + gpu_memory_utilization: 0.15 + limit_mm_per_prompt: + audio: 1 + image: 1 + max_model_len: 8192 + max_num_seqs: 8 + model_arch: HCXVisionV2ForCausalLM + model_stage: thinker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + tensor_parallel_size: 4 + trust_remote_code: true + worker_type: ar + final_output: true + final_output_type: text + is_comprehension: true + runtime: + devices: 0,1,2,3 + max_batch_size: 8 + process: true + stage_id: 0 + stage_type: llm +- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni.thinker2vision_decoder + default_sampling_params: + guidance_scale: 0.0 + height: 768 + num_inference_steps: 50 + seed: 42 + width: 768 + engine_args: + distributed_executor_backend: mp + enforce_eager: true + engine_output_type: image + gpu_memory_utilization: 0.75 + model_class_name: HyperCLOVAXVisionPipeline + model_stage: decoder/vision + model_subdir: decoder/vision + trust_remote_code: true + engine_input_source: + - 0 + final_output: true + final_output_type: image + runtime: + devices: '4' + max_batch_size: 1 + process: true + stage_id: 1 + stage_type: diffusion +- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hyperclovax_seed_omni.thinker2audio_decoder + default_sampling_params: + seed: 42 + engine_args: + distributed_executor_backend: mp + enforce_eager: true + engine_output_type: audio + gpu_memory_utilization: 0.4 + model_class_name: HyperCLOVAXAudioPipeline + model_stage: decoder/audio + model_subdir: decoder/audio/NCZSCosybigvganDecoder.mar + trust_remote_code: true + engine_input_source: + - 0 + final_output: true + final_output_type: audio + runtime: + devices: '5' + max_batch_size: 1 + process: true + stage_id: 2 + stage_type: diffusion diff --git a/vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py b/vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py new file mode 100644 index 00000000000..e0e31534f57 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/hyperclovax_seed_omni.py @@ -0,0 +1,156 @@ +"""Stage input processors for HyperCLOVAX-SEED-Omni-8B pipeline. + +The thinker generates a mixed token sequence containing: + - Regular text tokens (< 128606) + - Discrete audio tokens (128606 ~ 135167) + - Discrete vision tokens (135168 ~ 135168+255) + +These processors extract the relevant discrete tokens and route them +to the appropriate decoder stage. +""" + +import torch +from vllm.inputs import TextPrompt + +from vllm_omni.inputs.data import OmniTokensPrompt + +# Token ID boundaries from config.json +DISCRETE_AUDIO_UNIT_0_ID = 128606 +DISCRETE_IMAGE_UNIT_0_ID = 135168 +DISCRETE_AUDIO_VOCAB_SIZE = 6561 # CosyVoice2 FSQ codebook +DISCRETE_IMAGE_VOCAB_SIZE = 65536 # TA-Tok SimVQ codebook (2^16) +DISCRETE_IMAGE_TOKEN_LENGTH = 729 # 27x27 latent tokens per image + + +def _extract_discrete_tokens( + token_ids: list[int], start_id: int, vocab_size: int +) -> list[int]: + """Extract and remap discrete tokens from a mixed token sequence. + + Returns tokens remapped to [0, vocab_size) range. + """ + return [ + tid - start_id + for tid in token_ids + if start_id <= tid < start_id + vocab_size + ] + + +def thinker2vision_decoder( + stage_list, + engine_input_source, + prompt: OmniTokensPrompt | TextPrompt = None, + requires_multimodal_data: bool = False, +): + """Extract discrete vision tokens from thinker output → vision decoder. + + The vision decoder (HyperCLOVAXVisionPipeline) takes 256 discrete codes + per image and converts them to pixel images via diffusion. + """ + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + thinker_outputs = stage_list[source_stage_id].engine_outputs + if thinker_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + vision_decoder_inputs = [] + for thinker_output in thinker_outputs: + # Text-only iterations can produce an empty outputs list. + if not thinker_output.outputs: + continue + output = thinker_output.outputs[0] + output_token_ids = list(output.token_ids) + vision_codes = _extract_discrete_tokens( + output_token_ids, DISCRETE_IMAGE_UNIT_0_ID, DISCRETE_IMAGE_VOCAB_SIZE + ) + + if not vision_codes: + continue + + # Truncate/pad to exact DISCRETE_IMAGE_TOKEN_LENGTH (27x27=729). + # The LLM may generate slightly more or fewer tokens than expected; + # the vision decoder rearranges as (h w) → (h, w) so the length must be + # a perfect square == DISCRETE_IMAGE_TOKEN_LENGTH. + vision_codes = vision_codes[:DISCRETE_IMAGE_TOKEN_LENGTH] + vision_codes += [0] * (DISCRETE_IMAGE_TOKEN_LENGTH - len(vision_codes)) + + # Pipeline expects vision_tokens key in req.extra + vision_decoder_inputs.append( + OmniTokensPrompt( + prompt_token_ids=vision_codes, + additional_information={ + "request_id": thinker_output.request_id, + "vision_tokens": vision_codes, + "num_images": 1, + }, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return vision_decoder_inputs + + +def thinker2audio_decoder( + stage_list, + engine_input_source, + prompt: OmniTokensPrompt | TextPrompt = None, + requires_multimodal_data: bool = False, +): + """Extract discrete audio tokens from thinker output → audio decoder. + + The audio decoder (Unit-BigVGAN) takes discrete audio codes (6561 vocab) + and converts them to 24kHz waveforms. + """ + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + thinker_outputs = stage_list[source_stage_id].engine_outputs + if thinker_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + audio_decoder_inputs = [] + for thinker_output in thinker_outputs: + # Text-only iterations can produce an empty outputs list. + if not thinker_output.outputs: + continue + output = thinker_output.outputs[0] + output_token_ids = list(output.token_ids) + + audio_codes = _extract_discrete_tokens( + output_token_ids, DISCRETE_AUDIO_UNIT_0_ID, DISCRETE_AUDIO_VOCAB_SIZE + ) + + if not audio_codes: + continue + + # Pipeline expects audio_tokens as list[list[int]] (batch), + # speakers as list[str], formats as list[str], and optional + # ref_audio_tokens for zero-shot TTS (ECAPA-TDNN speaker embedding). + # ref_audio_b64 is the raw base64 audio from the user's input message, + # injected by serving_chat.py into the engine_prompt dict. + _ref = None + if isinstance(prompt, dict): + _ref = prompt.get("ref_audio_b64") + elif isinstance(prompt, list) and prompt: + _p = prompt[0] + if isinstance(_p, dict): + _ref = _p.get("ref_audio_b64") + audio_decoder_inputs.append( + OmniTokensPrompt( + prompt_token_ids=audio_codes, + additional_information={ + "request_id": thinker_output.request_id, + "audio_tokens": [audio_codes], + "speakers": ["fkms"], + "ref_audio_tokens": [_ref], + }, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return audio_decoder_inputs From 936460fa119e67cdcdac6acc531142593ac655a6 Mon Sep 17 00:00:00 2001 From: kje Date: Mon, 6 Apr 2026 09:18:25 +0900 Subject: [PATCH 2/4] fix: async fan-out topology, serving pipeline, and vLLM 0.18.0 compat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - async_omni.py: redesign _process_sequential_results for fan-out topology — Stage-0 forwards to Stage-1 (vision) AND Stage-2 (audio) independently based on engine_input_source; add skipped_stages for conditional routing - serving_chat.py: add _stage0_is_llm guard so GLM-Image bare-text replacement does not clobber HCX Omni Stage-0 multimodal inputs; handle audio output in _create_chat_completion_response - async_omni_diffusion.py, omni_stage.py: vLLM 0.18.0 API alignment - worker/gpu_ar_model_runner.py, async_omni_llm.py: compatibility fixes Co-Authored-By: 길재은 Co-Authored-By: Hyunjoon Jeong --- vllm_omni/entrypoints/async_omni.py | 571 ++++++---- vllm_omni/entrypoints/async_omni_diffusion.py | 412 +++++-- vllm_omni/entrypoints/async_omni_llm.py | 22 +- vllm_omni/entrypoints/omni_stage.py | 1006 +++++++++-------- vllm_omni/entrypoints/openai/serving_chat.py | 816 ++++++++----- .../qwen2_5_omni/qwen2_5_omni_thinker.py | 28 +- .../qwen3_omni/qwen3_omni_moe_thinker.py | 26 +- vllm_omni/worker/gpu_ar_model_runner.py | 2 +- 8 files changed, 1829 insertions(+), 1054 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 3c275147fa0..04f493dd7e7 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -1,32 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import copy import time import weakref -from collections.abc import AsyncGenerator, Iterable -from dataclasses import asdict -from pprint import pformat +from collections.abc import AsyncGenerator, Iterable, Sequence from typing import Any from vllm.config import VllmConfig from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.plugins.io_processors import get_io_processor from vllm.sampling_params import SamplingParams from vllm.tokenizers import TokenizerLike from vllm.v1.engine.exceptions import EngineDeadError -# Internal imports (our code) from vllm_omni.config import OmniModelConfig from vllm_omni.diffusion.data import DiffusionParallelConfig -from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector +from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length, try_send_via_connector from vllm_omni.distributed.ray_utils.utils import try_close_ray from vllm_omni.engine.input_processor import OmniInputProcessor from vllm_omni.entrypoints.client_request_state import ClientRequestState -from vllm_omni.entrypoints.log_utils import ( - OrchestratorMetrics, -) from vllm_omni.entrypoints.omni import OmniBase from vllm_omni.entrypoints.omni_stage import OmniStage from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType @@ -34,12 +28,17 @@ from vllm_omni.entrypoints.utils import ( get_final_stage_id_for_e2e, ) +from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams + +# Internal imports (our code) +from vllm_omni.lora.request import LoRARequest +from vllm_omni.metrics import OrchestratorAggregator from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) -def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handler): +def _weak_close_cleanup_async(stage_list, stage_in_queues, stage_out_queues, ray_pg, output_handler, zmq_ctx=None): """Weak reference cleanup function for AsyncOmni instances.""" if stage_list: for q in stage_in_queues: @@ -47,6 +46,13 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handle q.put_nowait(SHUTDOWN_TASK) except Exception as e: logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() + for q in stage_out_queues: + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() for stage in stage_list: try: stage.stop_stage_worker() @@ -56,6 +62,8 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handle # Cancel output handler if output_handler is not None: output_handler.cancel() + if zmq_ctx is not None: + zmq_ctx.term() class AsyncOmni(OmniBase): @@ -65,10 +73,8 @@ class AsyncOmni(OmniBase): asynchronous LLM and Diffusion models. Args: - *args: Variable length argument list. - - args[0]: Model name or path to load. + model: Model name or path to load. **kwargs: Arbitrary keyword arguments. - - model: Model name or path to load (if not in args). - stage_configs_path: Optional path to YAML file containing stage configurations. If None, configurations are loaded from the model. - log_stats: Whether to enable statistics logging @@ -94,7 +100,7 @@ class AsyncOmni(OmniBase): ... print(output) """ - def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: + def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: # Pause/resume control attributes self._pause_cond: asyncio.Condition = asyncio.Condition() self._paused: bool = False @@ -103,7 +109,7 @@ def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: self.request_states: dict[str, ClientRequestState] = {} self.output_handler: asyncio.Task | None = None - super().__init__(*args, **kwargs) + super().__init__(model, **kwargs) # Register weak reference cleanup (called on garbage collection) self._weak_finalizer = weakref.finalize( @@ -111,8 +117,10 @@ def __init__(self, *args: Any, **kwargs: dict[str, Any]) -> None: _weak_close_cleanup_async, self.stage_list, self._stage_in_queues, + self._stage_out_queues, self._ray_pg, self.output_handler, + self._zmq_ctx, ) def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[str, Any]: @@ -133,9 +141,20 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st sequence_parallel_size = kwargs.get("sequence_parallel_size") tensor_parallel_size = kwargs.get("tensor_parallel_size") or 1 cfg_parallel_size = kwargs.get("cfg_parallel_size") or 1 + use_hsdp = kwargs.get("use_hsdp", False) + hsdp_shard_size = kwargs.get("hsdp_shard_size", -1) + hsdp_replicate_size = kwargs.get("hsdp_replicate_size", 1) if sequence_parallel_size is None: sequence_parallel_size = ulysses_degree * ring_degree - num_devices = sequence_parallel_size * tensor_parallel_size * cfg_parallel_size + + # Calculate num_devices: consider standalone HSDP + other_parallel_size = sequence_parallel_size * tensor_parallel_size * cfg_parallel_size + if use_hsdp and other_parallel_size == 1 and hsdp_shard_size > 0: + # Standalone HSDP: num_devices is determined by HSDP dimensions + num_devices = hsdp_shard_size * hsdp_replicate_size + else: + num_devices = other_parallel_size + for i in range(1, num_devices): devices += f",{i}" parallel_config = DiffusionParallelConfig( @@ -146,6 +165,9 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st ulysses_degree=ulysses_degree, ring_degree=ring_degree, cfg_parallel_size=cfg_parallel_size, + use_hsdp=use_hsdp, + hsdp_shard_size=hsdp_shard_size, + hsdp_replicate_size=hsdp_replicate_size, ) default_stage_cfg = [ { @@ -162,8 +184,12 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, "cache_config": cache_config, + "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False), "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), "enforce_eager": kwargs.get("enforce_eager", False), + "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), + "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), }, "final_output": True, "final_output_type": "image", @@ -192,11 +218,10 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: if stage.vllm_config is not None and stage.tokenizer is not None: try: vllm_config = stage.vllm_config - tokenizer = stage.tokenizer # Initialize input_processor + # OMNI: OmniInputProcessor creates tokenizer internally from vllm_config self.input_processor = OmniInputProcessor( vllm_config=vllm_config, - tokenizer=tokenizer, ) # Initialize model_config self.model_config = vllm_config.model_config @@ -213,7 +238,45 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: logger.warning( f"[{self._name}] Failed to initialize processors from stage-{stage.stage_id}: {e}", ) - # If no LLM stage found, set processors to None + # If no LLM stage found via ZMQ payload, fall back to creating from stage.engine_args + if not hasattr(self, "input_processor") or self.input_processor is None: + for stage in self.stage_list: + if stage.stage_type == "llm" and hasattr(stage, "engine_args") and stage.engine_args is not None: + try: + logger.info( + f"[{self._name}] stage-{stage.stage_id} vllm_config not received via ZMQ, " + "falling back to create_engine_config from stage.engine_args" + ) + from vllm.usage.usage_lib import UsageContext + from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs + from vllm_omni.entrypoints.omni_stage import filter_dataclass_kwargs + try: + from omegaconf import OmegaConf + _ea = OmegaConf.to_container(stage.engine_args, resolve=True) + except Exception: + _ea = dict(stage.engine_args) + _ea = filter_dataclass_kwargs(AsyncOmniEngineArgs, _ea) + _ea.pop("model", None) + _model = getattr(self, "_model", None) + if _model is None: + raise RuntimeError("Cannot determine model path for fallback") + _omni_ea = AsyncOmniEngineArgs(model=_model, **_ea) + vllm_config = _omni_ea.create_engine_config( + usage_context=UsageContext.API_SERVER + ) + stage.set_vllm_config(vllm_config) + self.input_processor = OmniInputProcessor(vllm_config=vllm_config) + self.model_config = vllm_config.model_config + io_processor_plugin = self.model_config.io_processor_plugin + self.io_processor = get_io_processor(vllm_config, io_processor_plugin) + logger.info( + f"[{self._name}] Initialized input_processor from stage-{stage.stage_id} engine_args fallback" + ) + break + except Exception as e: + logger.warning( + f"[{self._name}] Fallback init failed for stage-{stage.stage_id}: {e}" + ) if not hasattr(self, "input_processor") or self.input_processor is None: logger.warning( f"[{self._name}] No LLM stage found, processors will not be available. " @@ -232,7 +295,14 @@ def shutdown(self): if hasattr(self, "_weak_finalizer"): self._weak_finalizer() - async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator[OmniRequestOutput, None]: + async def generate( + self, + prompt: OmniPromptType, + request_id: str, + sampling_params_list: Sequence[OmniSamplingParams] | None = None, + *, + output_modalities: list[str] | None = None, + ) -> AsyncGenerator[OmniRequestOutput, None]: """Generate outputs for the given prompt asynchronously. Coordinates multi-stage pipeline through YAML configuration. @@ -242,21 +312,13 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator sampling parameters from the sampling_params_list. Args: - *args: Arguments for generation. - - prompt: Prompt to process. Can be a text string, token IDs, - or multimodal prompt. - - request_id: Unique identifier for this request - - sampling_params_list: List of SamplingParams, one for each stage. - Must have the same length as the number of stages. - If None, uses default sampling params for each stage. - **kwargs: Additional arguments for generation. - - prompt: Prompt to process. Can be a text string, token IDs, - or multimodal prompt. - - request_id: Unique identifier for this request - - sampling_params_list: List of SamplingParams, one for each stage. - Must have the same length as the number of stages. - If None, uses default sampling params for each stage. - - output_modalities: Optional list of output modalities. + prompt: Prompt to process. Can be a text string, token IDs, + or multimodal prompt. + request_id: Unique identifier for this request + sampling_params_list: List of SamplingParams, one for each stage. + Must have the same length as the number of stages. + If None, uses default sampling params for each stage. + output_modalities: Optional list of output modalities. Yields: OmniRequestOutput objects as they are produced by each stage. @@ -275,33 +337,9 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator # Start output handler on the first call to generate() self._run_output_handler() - prompt = args[0] if args else kwargs.get("prompt") - request_id = args[1] if len(args) > 1 else kwargs.get("request_id") - sampling_params_list = args[2] if len(args) > 2 else kwargs.get("sampling_params_list") - output_modalities = kwargs.get("output_modalities", None) # TODO: lora_request, trace_headers, priority are not supported yet - if sampling_params_list is None: - # For Omni LLM, the params are parsed via the yaml file. For the current version, - # diffusion params can parsed via the command line. - omni_params_kwargs = { - k: v for k, v in kwargs.items() if k not in ["prompt", "request_id", "output_modalities"] - } - - per_stage_params: list[Any] = [] - for stage_id, stage in enumerate(self.stage_list): - stage_type = getattr(stage, "stage_type", "llm") - if stage_type == "diffusion": - default_dict = self.default_sampling_params_list[stage_id] - # Merge user-provided kwargs - merged = {**default_dict, **omni_params_kwargs} - # Diffusion only needs to keep diff params, will be used via OmniDiffusionRequest - per_stage_params.append(merged) - else: - # LLM directly constructs SamplingParams, don't use the merged params - per_stage_params.append(self.default_sampling_params_list[stage_id]) - - sampling_params_list = per_stage_params + sampling_params_list = self.default_sampling_params_list if len(sampling_params_list) != len(self.stage_list): raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") @@ -320,19 +358,15 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator ) # Metrics/aggregation helper - metrics = OrchestratorMetrics( - num_stages, - self._enable_stats, - _wall_start_ts, + metrics = OrchestratorAggregator( + num_stages=num_stages, + log_stats=self.log_stats, + wall_start_ts=_wall_start_ts, + final_stage_id_for_e2e=final_stage_id_for_e2e, ) - # Seed stage-0 queue with all requests - logger.debug(f"[{self._name}] Seeding request into stage-0") req_state = ClientRequestState(request_id) req_state.metrics = metrics self.request_states[request_id] = req_state - # Mark first input time for stage-0 - metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() - sp0: SamplingParams = sampling_params_list[0] # type: ignore[index] task = { "request_id": request_id, @@ -340,133 +374,48 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator "sampling_params": sp0, } self.stage_list[0].submit(task) + metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() _req_start_ts[request_id] = time.time() - logger.debug(f"[{self._name}] Enqueued request {request_id} to stage-0") - - logger.debug(f"[{self._name}] Entering scheduling loop: stages={num_stages}") - for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): - finished = False - while not finished: - result = await req_state.queue.get() - assert stage_id == req_state.stage_id - - req_id = result.get("request_id") - if "error" in result: - logger.error( - f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", - ) - raise RuntimeError(result) # Request Finished due to error - - engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") - if isinstance(engine_outputs, list): - engine_outputs = engine_outputs[0] - finished = engine_outputs.finished - - # Mark last output time for this stage whenever we receive outputs - metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) - try: - _m = asdict(result.get("metrics")) - if _m is not None and finished: - metrics.on_stage_metrics(stage_id, req_id, _m) - except Exception as e: - logger.exception( - f"[{self._name}] Failed to process metrics for stage {stage_id}, req {req_id}: {e}", - ) - logger.debug( - f"[{self._name}] Stage-{stage_id} completed request {req_id}; forwarding or finalizing", - ) - - if getattr(stage, "final_output", False): - logger.debug( - f"[{self._name}] Request {req_id} finalized at stage-{stage_id}", - ) - - # End-to-end timing and time-per-token for final output - # (only once per request at the designated final stage) - try: - rid_key = str(req_id) - if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done and finished: - metrics.on_finalize_request( - stage_id, - req_id, - _req_start_ts.get(req_id, _wall_start_ts), - ) - except Exception as e: - logger.exception( - f"[{self._name}] Finalize request handling error for req " - f"{req_id} at stage {stage_id}: {e}", - ) - - # Handle diffusion outputs that already contain images - if stage.final_output_type == "image": - images = [] - if isinstance(engine_outputs, OmniRequestOutput) and engine_outputs.images: - images = engine_outputs.images - elif hasattr(engine_outputs, "images") and engine_outputs.images: - images = engine_outputs.images - yield OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, - request_output=engine_outputs, - images=images, - ) - else: - yield OmniRequestOutput( - stage_id=stage_id, - final_output_type=stage.final_output_type, - request_output=engine_outputs, - ) - if not isinstance(engine_outputs, list): - engine_outputs = [engine_outputs] - stage.set_engine_outputs(engine_outputs) - # Forward to next stage if there is one - next_stage_id = stage_id + 1 - if next_stage_id <= final_stage_id_for_e2e and finished: - next_stage: OmniStage = self.stage_list[next_stage_id] - next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) - sp_next: SamplingParams = sampling_params_list[next_stage_id] - - # Check if we have a connector for this edge - connector_key = (str(stage_id), str(next_stage_id)) - connector = self.connectors.get(connector_key) - - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=req_id, - next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=prompt, - next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - # Fallback logic removed as we now enforce connector usage. - # If no connector is found or send fails, we log an error and raise, - # because continuing would cause the request to be silently dropped - # and the orchestrator to hang waiting for completion. - error_msg = ( - f"[{self._name}] Failed to send request {req_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." - ) - logger.error(error_msg) - raise RuntimeError(error_msg) - logger.debug(f"[{self._name}] Forwarded request {req_id} to stage-{next_stage_id}") - else: - logger.debug(f"[{self._name}] Request {req_id} fully completed") - - logger.debug(f"[{self._name}] All requests completed") - - # Summarize and print stats + logger.info( + f"[{self._name}] Entering scheduling loop: stages={num_stages}, final_stage={final_stage_id_for_e2e}" + ) + if self.async_chunk: + stage_queues = {stage_id: asyncio.Queue() for stage_id in range(num_stages)} + req_state.stage_queues = stage_queues + async for output in self._process_async_results( + request_id, + prompt, + sampling_params_list, + req_state, + metrics, + final_stage_id_for_e2e, + ): + yield output + else: + async for output in self._process_sequential_results( + request_id, + req_state, + metrics, + final_stage_id_for_e2e, + sampling_params_list, + prompt, + ): + yield output + + logger.debug(f"[{self._name}] Request {request_id} finalized at stage-{final_stage_id_for_e2e}") try: - summary = metrics.build_and_log_summary(final_stage_id_for_e2e) - logger.info("[Summary] %s", pformat(summary, sort_dicts=False)) + # Finalize E2E metrics if not already done + metrics.on_finalize_request( + final_stage_id_for_e2e, + request_id, + _req_start_ts.get(request_id, _wall_start_ts), + ) + + logger.debug(f"[{self._name}] All requests completed") + # Summarize and print stats + metrics.build_and_log_summary() except Exception as e: - logger.exception(f"[{self._name}] Failed to build/log summary: {e}") + logger.exception(f"[{self._name}] Request {request_id} Failed to finalized/build/log summary: {e}") finally: self.request_states.pop(request_id, None) except (asyncio.CancelledError, GeneratorExit): @@ -474,6 +423,216 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator logger.info("[AsyncOrchestrator] Request %s aborted.", request_id) raise + async def _process_async_results( + self, + request_id: str, + prompt: Any, + sampling_params_list: list[SamplingParams], + req_state: ClientRequestState, + metrics: OrchestratorAggregator, + final_stage_id_for_e2e: int, + ) -> AsyncGenerator[OmniRequestOutput, None]: + all_stages_finished = {stage_id: False for stage_id in range(final_stage_id_for_e2e + 1)} + submit_flag = True + while not all(all_stages_finished.values()): + for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): + if all_stages_finished[stage_id]: + continue + try: + result = req_state.stage_queues[stage_id].get_nowait() + except asyncio.QueueEmpty: + await asyncio.sleep(0.001) + continue + engine_outputs, finished, output_to_yield = self._process_single_result( + result, + stage, + stage_id, + metrics, + ) + if submit_flag and stage_id == 0: + submit_flag = False + prompt_token_ids = engine_outputs.prompt_token_ids + engine_input = copy.deepcopy(prompt) + next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids)) + engine_input["prompt_token_ids"] = [0] * next_prompt_len + engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None + for i in range(1, len(self.stage_list)): + task = { + "request_id": request_id, + "engine_inputs": engine_input, + "sampling_params": sampling_params_list[i], + } + self.stage_list[i].submit(task) + metrics.stage_first_ts[i] = time.time() + all_stages_finished[stage_id] = finished + + if output_to_yield: + yield output_to_yield + + async def _process_sequential_results( + self, + request_id: str, + req_state: ClientRequestState, + metrics: OrchestratorAggregator, + final_stage_id_for_e2e: int, + sampling_params_list: list[SamplingParams], + prompt: Any, + ) -> AsyncGenerator[OmniRequestOutput, None]: + # Track stages that were never submitted (no inputs); skip waiting for them. + # This handles the fan-out topology where Stage-0 forwards to BOTH Stage-1 + # (vision decoder) and Stage-2 (audio decoder) independently, based on + # which token types appeared in Stage-0 output. + skipped_stages: set[int] = set() + for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): + if stage_id in skipped_stages: + continue + finished = False + while not finished: + result = await req_state.queue.get() + assert stage_id == req_state.stage_id + engine_outputs, finished, output_to_yield = self._process_single_result( + result, + stage, + stage_id, + metrics, + ) + if output_to_yield: + yield output_to_yield + if not isinstance(engine_outputs, list): + engine_outputs = [engine_outputs] + stage.set_engine_outputs(engine_outputs) + # Forward to all subsequent stages whose engine_input_source includes + # this stage. Both Stage-1 (vision) and Stage-2 (audio) source from + # Stage-0 independently, so we must try both after Stage-0 completes. + any_forwarded = False + for next_stage_id in range(stage_id + 1, final_stage_id_for_e2e + 1): + next_stage: OmniStage = self.stage_list[next_stage_id] + if stage_id not in getattr(next_stage, "engine_input_source", []): + continue + # Derive inputs for the next stage, record postprocess time + with metrics.stage_postprocess_timer(stage_id, request_id): + next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) + sp_next: SamplingParams = sampling_params_list[next_stage_id] + if not next_inputs: + logger.warning( + "[%s] No inputs for stage-%s (request %s), skipping forward", + self._name, next_stage_id, request_id, + ) + skipped_stages.add(next_stage_id) + continue + # Check if we have a connector for this edge + connector_key = (str(stage_id), str(next_stage_id)) + connector = self.connectors.get(connector_key) + sent_via_connector = False + if connector: + sent_via_connector = try_send_via_connector( + connector=connector, + stage_id=stage_id, + next_stage_id=next_stage_id, + req_id=request_id, + next_inputs=next_inputs, + sampling_params=sp_next, + original_prompt=prompt, + next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + metrics=metrics, + ) + if not sent_via_connector: + # Fallback logic removed as we now enforce connector usage. + # If no connector is found or send fails, we log an error and raise, + # because continuing would cause the request to be silently dropped + # and the orchestrator to hang waiting for completion. + error_msg = ( + f"[{self._name}] Failed to send request {request_id} to stage-{next_stage_id} via connector. " + "Configure a connector for this edge or inspect connector logs for details." + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + logger.debug(f"[{self._name}] Forwarded request {request_id} to stage-{next_stage_id}") + any_forwarded = True + if not any_forwarded: + logger.debug(f"[{self._name}] Request {request_id} fully completed at stage-{stage_id}") + + def _process_single_result( + self, + result: dict[str, Any], + stage: OmniStage, + stage_id: int, + metrics: OrchestratorAggregator, + ) -> tuple[Any, bool, OmniRequestOutput | None]: + """ + Process a single result dictionary from a stage. + Returns: + engine_outputs: The decoded outputs. + finished: Whether the stage processing is finished for this request. + output_to_yield: An OmniRequestOutput to yield, or None. + """ + req_id = result.get("request_id") + + if result.get("skipped"): + logger.info(f"[{self._name}] Stage {stage_id} skipped request {req_id} (no engine inputs)") + class _SkippedOutput: + finished = True + prompt_token_ids = [] + return _SkippedOutput(), True, None + + if "error" in result: + logger.error( + f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}", + ) + raise RuntimeError(result) + + engine_outputs = _load(result, obj_key="engine_outputs", shm_key="engine_outputs_shm") + if isinstance(engine_outputs, list): + engine_outputs = engine_outputs[0] + + finished = engine_outputs.finished + + output_to_yield = None + + if getattr(stage, "final_output", False): + # Construct output to yield + images = [] + if stage.final_output_type == "image": + if isinstance(engine_outputs, OmniRequestOutput) and engine_outputs.images: + images = engine_outputs.images + elif hasattr(engine_outputs, "images") and engine_outputs.images: + images = engine_outputs.images + + if stage.final_output_type == "image": + output_to_yield = OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + images=images, + finished=finished, + ) + else: + output_to_yield = OmniRequestOutput( + stage_id=stage_id, + final_output_type=stage.final_output_type, + request_output=engine_outputs, + finished=finished, + ) + # Mark last output time + metrics.stage_last_ts[stage_id] = max(metrics.stage_last_ts[stage_id] or 0.0, time.time()) + + metrics.process_stage_metrics( + result=result, + stage_type=stage.stage_type, + stage_id=stage_id, + req_id=req_id, + engine_outputs=engine_outputs, + finished=finished, + final_output_type=stage.final_output_type, + output_to_yield=output_to_yield, + ) + + logger.debug( + f"[{self._name}] Stage-{stage_id} completed request {req_id}; forwarding or finalizing", + ) + + return engine_outputs, finished, output_to_yield + def _run_output_handler(self) -> None: if self.output_handler is not None: return @@ -503,8 +662,12 @@ async def output_handler(): dropping output for req {req_id} at stage-{stage_id}" ) continue - await req_state.queue.put(result) - req_state.stage_id = stage_id + if hasattr(req_state, "stage_queues") and stage_id in req_state.stage_queues: + await req_state.stage_queues[stage_id].put(result) + else: + # Fallback to old behavior for compatibility + await req_state.queue.put(result) + req_state.stage_id = stage_id if idle: await asyncio.sleep(0.001) # Avoid CPU overload when idle else: @@ -512,7 +675,14 @@ async def output_handler(): except Exception as e: logger.exception("AsyncOmni output_handler failed.") for req_state in request_states.values(): - await req_state.queue.put({"request_id": req_id, "error": str(e)}) + error_msg = {"request_id": req_state.request_id, "error": str(e)} + # Send error to all stage queues + if hasattr(req_state, "stage_queues"): + for queue in req_state.stage_queues.values(): + await queue.put(error_msg) + else: + await req_state.queue.put(error_msg) + error_msg = {"request_id": req_state.request_id, "error": str(e)} self.output_handler = None # Make possible for restart self.output_handler = asyncio.create_task(output_handler()) @@ -579,6 +749,15 @@ async def is_tracing_enabled(self) -> bool: return stage.is_tracing_enabled return False + @property + def renderer(self): + """Return the renderer from input_processor if available. + + OMNI: Required by upstream OpenAIServingModels.__init__ which + accesses engine_client.renderer. + """ + return self.input_processor.renderer + async def do_log_stats(self) -> None: pass diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 1a7f1174c27..2844377badb 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -10,23 +10,44 @@ import asyncio import uuid +import weakref from collections.abc import AsyncGenerator, Iterable from concurrent.futures import ThreadPoolExecutor -from dataclasses import fields from typing import Any -from PIL import Image from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict - -from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig +try: + from huggingface_hub.errors import HFValidationError as _HFValidationError +except ImportError: + _HFValidationError = ValueError + +from vllm_omni.diffusion.data import ( + DiffusionRequestAbortedError, + OmniDiffusionConfig, + TransformerConfig, +) from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) +def _weak_close_async_omni_diffusion(engine: DiffusionEngine, executor: ThreadPoolExecutor) -> None: + """Best-effort diffusion cleanup for GC finalization.""" + try: + engine.close() + except Exception: + pass + try: + executor.shutdown(wait=False) + except Exception: + pass + + class AsyncOmniDiffusion: """Async entry point for vLLM-Omni diffusion model inference. @@ -53,25 +74,81 @@ def __init__( self, model: str, od_config: OmniDiffusionConfig | None = None, + batch_size: int = 1, **kwargs: Any, ): self.model = model + # Set batch size (default 1 for backward compatibility) + self._batch_size = max(1, batch_size) + + # Capture stage info from kwargs before they might be filtered out + stage_id = kwargs.get("stage_id") + engine_input_source = kwargs.get("engine_input_source") + cfg_kv_collect_func = kwargs.pop("cfg_kv_collect_func", None) + # Build config if od_config is None: od_config = OmniDiffusionConfig.from_kwargs(model=model, **kwargs) elif isinstance(od_config, dict): + # If config is dict, check it too (priority to kwargs if both exist) + if stage_id is None: + stage_id = od_config.get("stage_id") + if engine_input_source is None: + engine_input_source = od_config.get("engine_input_source") od_config = OmniDiffusionConfig.from_kwargs(**od_config) self.od_config = od_config - # Load model class name and transformer config - config_dict = get_hf_file_to_dict("model_index.json", od_config.model) - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() + # Inject stage info into omni_kv_config if present + if stage_id is not None: + self.od_config.omni_kv_config.setdefault("stage_id", stage_id) + if engine_input_source is not None: + self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) - tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + # Diffusers-style models expose `model_index.json` with `_class_name`. + # Non-diffusers models (e.g. Bagel, NextStep) only have `config.json`, + # so we fall back to reading that and mapping model_type manually. + try: + config_dict = get_hf_file_to_dict("model_index.json", od_config.model) + if config_dict is not None: + if od_config.model_class_name is None: + od_config.model_class_name = config_dict.get("_class_name", None) + od_config.update_multimodal_support() + + tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + else: + raise FileNotFoundError("model_index.json not found") + except (AttributeError, KeyError, OSError, ValueError, FileNotFoundError, _HFValidationError): + cfg = get_hf_file_to_dict("config.json", od_config.model) + if cfg is None: + if od_config.model_class_name is not None: + cfg = {} # skip - use explicit model_class_name + else: + raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") + + od_config.tf_model_config = TransformerConfig.from_dict(cfg) + model_type = cfg.get("model_type") + architectures = cfg.get("architectures") or [] + # Bagel/NextStep models don't have a model_index.json, so we set the pipeline class name manually + if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: + od_config.model_class_name = "BagelPipeline" + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() + elif model_type == "nextstep": + if od_config.model_class_name is None: + od_config.model_class_name = "NextStep11Pipeline" + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() + elif architectures and len(architectures) == 1: + if od_config.model_class_name is None: + od_config.model_class_name = architectures[0] + elif od_config.model_class_name is None: + raise + + if cfg_kv_collect_func is not None: + od_config.cfg_kv_collect_func = cfg_kv_collect_func # Initialize engine self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config) @@ -79,67 +156,139 @@ def __init__( # Thread pool for running sync engine in async context self._executor = ThreadPoolExecutor(max_workers=1) self._closed = False + self._weak_finalizer = weakref.finalize( + self, + _weak_close_async_omni_diffusion, + self.engine, + self._executor, + ) - logger.info("AsyncOmniDiffusion initialized with model: %s", model) + logger.info("AsyncOmniDiffusion initialized with model: %s, batch_size: %d", model, self._batch_size) - def _prepare_request( + # ------------------------------------------------------------------ + # batch_size property + # ------------------------------------------------------------------ + + @property + def batch_size(self) -> int: + """Return the configured batch size for request batching.""" + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int) -> None: + if not isinstance(value, int) or value < 1: + raise ValueError("batch_size must be a positive integer") + self._batch_size = value + + # ------------------------------------------------------------------ + # Public batch generation API + # ------------------------------------------------------------------ + + async def generate_batch( self, - prompt: str, + prompts: list[OmniPromptType], + sampling_params: OmniDiffusionSamplingParams, request_id: str | None = None, - **kwargs: Any, - ) -> OmniDiffusionRequest: - """Prepare a diffusion request from prompt and parameters. + lora_request: LoRARequest | None = None, + ) -> OmniRequestOutput: + """Generate images from multiple prompts in a single engine call. + + Batches the given prompts into **one** ``DiffusionEngine.step()`` + call and returns a single ``OmniRequestOutput`` containing all + generated images. Called by ``StageDiffusionClient._run_batch`` + when the orchestrator receives a list-prompt request. Args: - prompt: Text prompt for image generation - request_id: Optional unique identifier for the request - **kwargs: Additional generation parameters + prompts: List of text prompts describing the desired images. + sampling_params: Shared sampling parameters for all prompts. + request_id: Optional unique identifier. Auto-generated when *None*. + lora_request: Optional LoRA adapter to apply. Returns: - OmniDiffusionRequest ready for processing + A single ``OmniRequestOutput`` with all images combined. """ if request_id is None: - request_id = f"diff-{uuid.uuid4().hex[:16]}" + request_id = f"diff-batch-{uuid.uuid4().hex[:8]}" + return await self._generate_batch(prompts, sampling_params, request_id, lora_request) + + # ------------------------------------------------------------------ + # Internal batch generation + # ------------------------------------------------------------------ + + async def _generate_batch( + self, + prompts: list[OmniPromptType], + sampling_params: OmniDiffusionSamplingParams, + request_id: str, + lora_request: LoRARequest | None = None, + ) -> OmniRequestOutput: + """Generate images from multiple prompts in a single engine call.""" + if not prompts: + return OmniRequestOutput(request_id=request_id, images=[], final_output_type="image") + + if sampling_params.guidance_scale: + sampling_params.guidance_scale_provided = True - field_names = {f.name for f in fields(OmniDiffusionRequest)} + if lora_request is not None: + sampling_params.lora_request = lora_request + + request = OmniDiffusionRequest( + prompts=prompts, + sampling_params=sampling_params, + request_ids=[f"{request_id}-{i}" for i in range(len(prompts))], + ) - init_kwargs = { - "prompt": prompt, - "request_id": request_id, - } + logger.debug("Starting batch generation for %d prompts, request_id=%s", len(prompts), request_id) - for key, value in kwargs.items(): - if key in field_names: - init_kwargs[key] = value + loop = asyncio.get_event_loop() + try: + results = await loop.run_in_executor( + self._executor, + self.engine.step, + request, + ) + except Exception as e: + logger.error("Batch generation failed for request %s: %s", request_id, e) + raise RuntimeError(f"Diffusion batch generation failed: {e}") from e - return OmniDiffusionRequest(**init_kwargs) + # Combine all per-prompt results into a single OmniRequestOutput + all_images = [] + for result in results: + all_images.extend(result.images) + + return OmniRequestOutput( + request_id=request_id, + images=all_images, + final_output_type="image", + finished=True, + ) + + def get_diffusion_od_config(self) -> OmniDiffusionConfig: + """Return the diffusion config used by this engine.""" + return self.od_config + + # ------------------------------------------------------------------ + # Public generate API + # ------------------------------------------------------------------ async def generate( self, - prompt: str, + prompt: OmniPromptType, + sampling_params: OmniDiffusionSamplingParams, request_id: str | None = None, - num_inference_steps: int = 50, - guidance_scale: float | None = None, - height: int | None = None, - width: int | None = None, - negative_prompt: str | None = None, - num_outputs_per_prompt: int = 1, - seed: int | None = None, - **kwargs: Any, + lora_request: LoRARequest | None = None, ) -> OmniRequestOutput: - """Generate images asynchronously from a text prompt. + """Generate images asynchronously from a single text prompt. + + For batched generation (multiple prompts in one engine call), use + :meth:`generate_batch` instead. This method always processes + exactly one prompt per call. Args: prompt: Text prompt describing the desired image + sampling_params: Sampling parameters request_id: Optional unique identifier for tracking the request - num_inference_steps: Number of denoising steps (default: 50) - guidance_scale: Classifier-free guidance scale (optional, uses model defaults if omitted) - height: Optional image height in pixels - width: Optional image width in pixels - negative_prompt: Optional negative prompt for guidance - num_outputs_per_prompt: Number of images to generate (default: 1) - seed: Optional random seed for reproducibility - **kwargs: Additional generation parameters + lora_request: Optional LoRA adapter to apply Returns: OmniRequestOutput containing generated images @@ -149,64 +298,49 @@ async def generate( """ if request_id is None: request_id = f"diff-{uuid.uuid4().hex[:16]}" - - # Prepare request - request_kwargs = { - "prompt": prompt, - "request_id": request_id, - "num_inference_steps": num_inference_steps, - "height": height, - "width": width, - "negative_prompt": negative_prompt, - "num_outputs_per_prompt": num_outputs_per_prompt, - "seed": seed, - **kwargs, - } - if guidance_scale is not None: - request_kwargs["guidance_scale"] = guidance_scale - - request = self._prepare_request(**request_kwargs) + if sampling_params.guidance_scale: + sampling_params.guidance_scale_provided = True + + if lora_request is not None: + sampling_params.lora_request = lora_request + + # Extract additional_information from OmniTokensPrompt into extra dict + # (carries audio_tokens, vision_tokens, etc. from thinker2*_decoder processors) + extra: dict = {} + if isinstance(prompt, dict) and prompt.get('additional_information'): + extra.update(prompt['additional_information']) + elif hasattr(prompt, 'additional_information') and prompt.additional_information: + extra.update(prompt.additional_information) + + request = OmniDiffusionRequest( + prompts=[prompt], + sampling_params=sampling_params, + request_ids=[request_id], + extra=extra if extra else {}, + ) logger.debug("Starting generation for request %s", request_id) - # Run engine in thread pool loop = asyncio.get_event_loop() try: result = await loop.run_in_executor( self._executor, self.engine.step, - [request], + request, ) + result = result[0] + except asyncio.CancelledError: + self.engine.abort(request_id) + raise + except DiffusionRequestAbortedError: + raise except Exception as e: logger.error("Generation failed for request %s: %s", request_id, e) raise RuntimeError(f"Diffusion generation failed: {e}") from e - # Check if result is already OmniRequestOutput - if isinstance(result, OmniRequestOutput): - # Update request_id if needed - if not result.request_id: - result.request_id = request_id - return result - - # Process results if not OmniRequestOutput - images: list[Image.Image] = [] - if result is not None: - if isinstance(result, list): - for item in result: - if isinstance(item, Image.Image): - images.append(item) - elif isinstance(result, Image.Image): - images.append(result) - - return OmniRequestOutput.from_diffusion( - request_id=request_id, - images=images, - prompt=prompt, - metrics={ - "num_inference_steps": num_inference_steps, - "guidance_scale": request.guidance_scale, - }, - ) + if not result.request_id: + result.request_id = request_id + return result async def generate_stream( self, @@ -240,6 +374,10 @@ def close(self) -> None: return self._closed = True + finalizer = getattr(self, "_weak_finalizer", None) + if finalizer is not None and finalizer.alive: + finalizer.detach() + try: self.engine.close() except Exception as e: @@ -256,13 +394,6 @@ def shutdown(self) -> None: """Alias for close() method.""" self.close() - def __del__(self) -> None: - """Best-effort cleanup on deletion.""" - try: - self.close() - except Exception: - pass - async def abort(self, request_id: str | Iterable[str]) -> None: """Abort a request.""" self.engine.abort(request_id) @@ -276,3 +407,84 @@ def is_running(self) -> bool: def is_stopped(self) -> bool: """Check if the engine is stopped.""" return self._closed + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "remove_lora", + None, + (adapter_id,), + {}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def add_lora(self, lora_request: LoRARequest) -> bool: + """Add a LoRA adapter""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "add_lora", + None, + (), + {"lora_request": lora_request}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "list_loras", + None, + (), + {}, + None, + ) + # collective_rpc returns list from workers; flatten unique ids + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + merged.update(part or []) + return sorted(merged) + + async def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "pin_lora", + None, + (), + {"adapter_id": lora_id}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: + """Start or stop profiling for the diffusion model. + + Args: + is_start: True to start profiling, False to stop. + profile_prefix: Optional prefix for trace filename (vLLM compat). + + Note: + Matches vLLM's worker.profile() signature for consistency. + Traces are saved automatically via on_trace_ready callback. + """ + loop = asyncio.get_event_loop() + await loop.run_in_executor( + self._executor, + self.engine.profile, + is_start, + profile_prefix, + ) diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py index 567af03770f..20fb787ecd8 100644 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -6,11 +6,10 @@ from typing import TYPE_CHECKING import torch -import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.tokenizers import init_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext @@ -111,12 +110,11 @@ def __init__( tokenizer = None else: # Tokenizer (+ ensure liveness if running in another process). - tokenizer = init_tokenizer_from_config(model_config=vllm_config.model_config) + tokenizer = cached_tokenizer_from_config(model_config=vllm_config.model_config) # InputProcessor (converts Inputs --> EngineCoreRequests). self.input_processor = OmniInputProcessor( vllm_config=vllm_config, - tokenizer=tokenizer, mm_registry=mm_registry, ) @@ -135,6 +133,10 @@ def __init__( self._pause_cond = asyncio.Condition() self._paused = False + # Set renderer for output handler compatibility with AsyncLLM + from vllm.renderers import renderer_from_config as _renderer_from_config + self.renderer = _renderer_from_config(self.vllm_config) + # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_async_mp_client( vllm_config=vllm_config, @@ -165,21 +167,23 @@ def __init__( except RuntimeError: pass - if envs.VLLM_TORCH_PROFILER_DIR and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: + # Use profiler_config from vllm_config (new way, aligned with vllm v1) + if vllm_config.profiler_config.profiler == "torch" and not vllm_config.profiler_config.ignore_frontend: + profiler_dir = vllm_config.profiler_config.torch_profiler_dir logger.info( "Torch profiler enabled. AsyncOmniLLM CPU traces will be collected under %s", - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_omni_llm" self.profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, ], - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + with_stack=vllm_config.profiler_config.torch_profiler_with_stack, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip, ), ) else: diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 26ac36ea8b4..d77ce99bd9d 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -15,9 +15,12 @@ import sys import time import traceback +from collections.abc import Sequence +from contextlib import contextmanager from dataclasses import fields -from typing import Any +from typing import Any, Literal, cast +from vllm import PromptType, RequestOutput from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -31,27 +34,199 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.distributed.omni_connectors import build_stage_connectors from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector +from vllm_omni.distributed.omni_connectors.connectors.base import OmniConnectorBase from vllm_omni.distributed.ray_utils.utils import kill_ray_actor, start_ray_actor -from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs +from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs, OmniEngineArgs from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion from vllm_omni.entrypoints.async_omni_llm import AsyncOmniLLM -from vllm_omni.entrypoints.log_utils import count_tokens_from_outputs from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion from vllm_omni.entrypoints.omni_llm import OmniLLM from vllm_omni.entrypoints.stage_utils import ( SHUTDOWN_TASK, OmniStageTaskType, + _resolve_model_tokenizer_paths, _to_dict, is_profiler_task, maybe_dump_to_shm, set_stage_devices, ) -from vllm_omni.inputs.data import OmniTokensPrompt -from vllm_omni.utils import detect_device_type +from vllm_omni.entrypoints.utils import detect_pid_host, filter_dataclass_kwargs +from vllm_omni.entrypoints.zmq_utils import ( + ZmqQueue, + create_zmq_queue, +) +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams, OmniTokensPrompt +from vllm_omni.metrics import count_tokens_from_outputs +from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) +@contextmanager +def _sequential_init_lock(engine_args: dict[str, Any], stage_init_timeout: int = 300): + """Acquire device locks for sequential init if NVML is unavailable. + + If process-scoped memory tracking is available (NVML works), stages can + safely initialize concurrently — each measures only its own GPU memory. + Otherwise, fall back to file-based locks to serialize initialization. + """ + from vllm_omni.worker.gpu_memory_utils import is_process_scoped_memory_available + + nvml_available = is_process_scoped_memory_available() + pid_host = detect_pid_host() + + if nvml_available and pid_host: + logger.info( + "NVML process-scoped memory available and PID host is available — concurrent init is safe, skipping locks" + ) + yield + return + else: + logger.info( + "Using sequential init locks (nvml_available=%s, pid_host=%s)", + nvml_available, + pid_host, + ) + + from vllm_omni.platforms import current_omni_platform + + # Get all parallel sizes from engine_args or parallel_config (defaults to 1) + if "parallel_config" in engine_args: + parallel_config = engine_args["parallel_config"] + tensor_parallel_size = parallel_config.get("tensor_parallel_size", 1) + pipeline_parallel_size = parallel_config.get("pipeline_parallel_size", 1) + data_parallel_size = parallel_config.get("data_parallel_size", 1) + prefill_context_parallel_size = parallel_config.get("prefill_context_parallel_size", 1) + sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) + cfg_parallel_size = parallel_config.get("cfg_parallel_size", 1) + else: + tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) + pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) + data_parallel_size = engine_args.get("data_parallel_size", 1) + prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) + sequence_parallel_size = 1 + cfg_parallel_size = 1 + + num_devices_per_stage = ( + tensor_parallel_size + * pipeline_parallel_size + * data_parallel_size + * prefill_context_parallel_size + * sequence_parallel_size + * cfg_parallel_size + ) + + # Get physical device IDs from device control env var + device_control_env = current_omni_platform.device_control_env_var + visible_devices_str = os.environ.get(device_control_env) + physical_devices = [] + + if visible_devices_str: + try: + physical_devices = [int(x.strip()) for x in visible_devices_str.split(",") if x.strip()] + except (ValueError, IndexError): + pass + + if not physical_devices: + num_devices = current_omni_platform.get_device_count() + physical_devices = list(range(num_devices)) + + num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) + devices_to_lock = sorted(physical_devices[:num_devices_to_lock]) + + logger.debug( + "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d, CFG=%d; will lock %d devices: %s", + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + prefill_context_parallel_size, + sequence_parallel_size, + cfg_parallel_size, + num_devices_to_lock, + devices_to_lock, + ) + + # Acquire exclusive locks for all devices using fcntl.flock + wait_start = time.time() + acquired_lock_fds = [] + + for device_id in devices_to_lock: + lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" + lock_acquired = False + + while not lock_acquired: + try: + lock_fd = os.open(lock_file, os.O_CREAT | os.O_RDWR, 0o644) + + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + os.ftruncate(lock_fd, 0) + os.write(lock_fd, f"{os.getpid()}\n".encode()) + os.fsync(lock_fd) + lock_acquired = True + acquired_lock_fds.append(lock_fd) + logger.debug("Acquired exclusive lock for device %s", device_id) + except BlockingIOError: + os.close(lock_fd) + + if time.time() - wait_start > stage_init_timeout: + logger.warning( + "Timeout waiting for device %s initialization lock, proceeding anyway", + device_id, + ) + break + + time.sleep(0.1) + except OSError as e: + logger.debug( + "Failed to acquire lock for device %s: %s, continuing anyway", + device_id, + e, + ) + try: + os.close(lock_fd) + except (OSError, NameError): + pass + break + + # Set FD_CLOEXEC to prevent child processes from inheriting locks + for lock_fd in acquired_lock_fds: + try: + flags = fcntl.fcntl(lock_fd, fcntl.F_GETFD) + fcntl.fcntl(lock_fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + except (OSError, ValueError): + pass + + try: + yield + finally: + for lock_fd in acquired_lock_fds: + try: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + os.close(lock_fd) + logger.debug("Released initialization lock (fd=%s)", lock_fd) + except (OSError, ValueError): + pass + + +def _resolve_worker_cls(engine_args: dict[str, Any]) -> None: + worker_type = engine_args.get("worker_type", None) + if not worker_type: + return + worker_cls = engine_args.get("worker_cls") + if worker_cls is not None and worker_cls != "auto": + return + from vllm_omni.platforms import current_omni_platform + + worker_type = str(worker_type).lower() + if worker_type == "ar": + engine_args["worker_cls"] = current_omni_platform.get_omni_ar_worker_cls() + elif worker_type == "generation": + engine_args["worker_cls"] = current_omni_platform.get_omni_generation_worker_cls() + else: + raise ValueError(f"Unknown worker_type: {worker_type}") + + def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: """Build OmniDiffusionConfig kwargs from engine args.""" od_config = engine_args.get("od_config", {}) @@ -61,38 +236,10 @@ def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: for key, value in engine_args.items(): if key in od_field_names: od_config[key] = value + od_config["model"] = model # restore resolved path return od_config -def prepare_sampling_params(sampling_params: Any, stage_type: str) -> Any: - """Prepare sampling parameters for the given stage type. - - Args: - sampling_params: Raw sampling parameters (dict or SamplingParams) - stage_type: Either "llm" or "diffusion" - - Returns: - Processed sampling parameters ready for engine consumption - """ - if stage_type == "diffusion": - # For diffusion stages: extract kwargs, handling different input types - if isinstance(sampling_params, dict): - diffusion_kwargs = dict(sampling_params) - else: - diffusion_kwargs = getattr(sampling_params, "__dict__", {}) or {} - - # Remove 'prompt' and 'request_id' to avoid conflict with explicit arguments - diffusion_kwargs.pop("prompt", None) - diffusion_kwargs.pop("request_id", None) - return diffusion_kwargs - - else: # stage_type == "llm" - # For LLM stages: ensure we have a SamplingParams object - if isinstance(sampling_params, dict): - return SamplingParams(**sampling_params) - return sampling_params - - class OmniStage: """Stage manager for orchestrating a single stage in the omni pipeline. @@ -106,7 +253,7 @@ class OmniStage: """ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): - logger.info(f"[OmniStage] stage_config: {stage_config}") + logger.debug(f"[OmniStage] stage_config: {stage_config}") self.stage_config = stage_config self.engine = None self.async_engine = None @@ -123,7 +270,13 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): self.engine_outputs = None self.is_comprehension = getattr(stage_config, "is_comprehension", False) # Support for different stage types: "llm" (default) or "diffusion" - self.stage_type = getattr(stage_config, "stage_type", "llm") + self.stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm") + if ( + "stage_id" in stage_config.engine_args + and stage_config.engine_args.stage_id != self.stage_id + and self.stage_id is not None + ): + stage_config.engine_args.stage_id = self.stage_id if hasattr(stage_config, "custom_process_input_func"): # Import the module specified in the config (already a full module path) module_path, func_name = stage_config.custom_process_input_func.rsplit(".", 1) @@ -134,13 +287,21 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): self.final_output = getattr(stage_config, "final_output", False) self.final_output_type = getattr(stage_config, "final_output_type", None) + self.tts_args = _to_dict(getattr(stage_config, "tts_args", {})) default_sampling_params = getattr(stage_config, "default_sampling_params", {}) # For LLM stage, this can directly be a SamplingParams-compatible dict; # For diffusion stage, this only serves as default values for diffusion kwargs. - self.default_sampling_params = _to_dict(default_sampling_params) + default_sampling_params = _to_dict(default_sampling_params) + # Further convert it to dataclass to check fields + try: + self.default_sampling_params = ( + SamplingParams if self.stage_type == "llm" else OmniDiffusionSamplingParams + )(**default_sampling_params) + except TypeError as error: + raise TypeError(f"Invalid default_sampling_params for stage {self.stage_id}: {error}") from error # Runtime orchestration state (added) - self._in_q: mp.Queue | None = None - self._out_q: mp.Queue | None = None + self._in_q: mp.queues.Queue | ZmqQueue | str | None = None + self._out_q: mp.queues.Queue | ZmqQueue | str | None = None self._proc: mp.Process | None = None self._shm_threshold_bytes: int = 65536 self._stage_init_timeout: int = stage_init_timeout @@ -202,12 +363,16 @@ def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None: self.engine_outputs = engine_outputs # ----------------- New Orchestration APIs ----------------- - def attach_queues(self, in_q: mp.Queue, out_q: mp.Queue) -> None: + def attach_queues( + self, + in_q: mp.queues.Queue | ZmqQueue | str | None, + out_q: mp.queues.Queue | ZmqQueue | str | None, + ) -> None: """Attach input and output queues for IPC communication. Args: - in_q: Input queue for receiving tasks from orchestrator - out_q: Output queue for sending results to orchestrator + in_q: Input queue for receiving tasks from orchestrator (queue object or endpoint string) + out_q: Output queue for sending results to orchestrator (queue object or endpoint string) """ self._in_q = in_q self._out_q = out_q @@ -223,8 +388,8 @@ def stop_profile(self) -> dict: # Wait for result from worker try: - # Profiling stop might take time to flush files, give it 180s - response = self._out_q.get(timeout=60000) + # Profiling stop might take time to flush files, give it 600s + response = self._out_q.get(timeout=600) if isinstance(response, dict): if response.get("type") == "profiler_result": @@ -254,6 +419,7 @@ def init_stage_worker( batch_timeout: int = 10, connectors_config: dict | None = None, worker_backend: str = "multi_process", + ignore_runtime_config: bool = False, **kwargs: Any, ) -> None: """Initialize and start the stage worker process. @@ -269,6 +435,7 @@ def init_stage_worker( batch_timeout: Timeout in seconds for batching requests connectors_config: Configuration for stage connectors worker_backend: Backend type ("multi_process" or "ray") + ignore_runtime_config: Whether to ignore runtime configuration (default: False) **kwargs: Additional arguments (e.g. ray_placement_group) Raises: @@ -286,7 +453,10 @@ def init_stage_worker( ctx = ctx or mp.get_context("spawn") # Prepare lightweight dict config for worker engine_args = _to_dict(self.engine_args) - runtime_cfg = _to_dict(getattr(self.stage_config, "runtime", {})) + if ignore_runtime_config: + runtime_cfg = {} + else: + runtime_cfg = _to_dict(getattr(self.stage_config, "runtime", {})) stage_payload: dict[str, Any] = { "stage_id": self.stage_id, "engine_args": engine_args, @@ -294,6 +464,9 @@ def init_stage_worker( "shm_threshold_bytes": self._shm_threshold_bytes, "connectors_config": connectors_config or {}, "stage_type": self.stage_type, + "engine_input_source": self.engine_input_source, + "final_output": self.final_output, + "final_output_type": self.final_output_type, } try: old_env = os.environ.get("VLLM_LOGGING_PREFIX") @@ -305,9 +478,10 @@ def init_stage_worker( _stage_worker_async_entry, ray_placement_group, self.stage_id, - self, model=model, stage_payload=stage_payload, + in_q=self._in_q, + out_q=self._out_q, batch_timeout=batch_timeout, stage_init_timeout=self._stage_init_timeout, ) @@ -328,9 +502,10 @@ def init_stage_worker( self._proc = ctx.Process( target=_stage_worker_async_entry, args=( - self, model, stage_payload, + self._in_q.endpoint if isinstance(self._in_q, ZmqQueue) else self._in_q, + self._out_q.endpoint if isinstance(self._out_q, ZmqQueue) else self._out_q, batch_timeout, self._stage_init_timeout, ), @@ -341,8 +516,8 @@ def init_stage_worker( args=( model, stage_payload, - self._in_q, - self._out_q, + self._in_q.endpoint if isinstance(self._in_q, ZmqQueue) else self._in_q, + self._out_q.endpoint if isinstance(self._out_q, ZmqQueue) else self._out_q, batch_timeout, self._stage_init_timeout, ), @@ -366,6 +541,13 @@ def stop_stage_worker(self) -> None: self._in_q.put_nowait(SHUTDOWN_TASK) except Exception as e: logger.warning("Failed to send shutdown to in_q: %s", e) + close_fn = getattr(self._in_q, "close", None) + if callable(close_fn): + close_fn() + if self._out_q is not None: + close_fn = getattr(self._out_q, "close", None) + if callable(close_fn): + close_fn() if hasattr(self, "_ray_actor") and self._ray_actor: kill_ray_actor(self._ray_actor) @@ -389,6 +571,36 @@ def submit(self, payload: dict[str, Any]) -> None: sampling_params, etc.) """ assert self._in_q is not None + + # [Omni] Inject global request_id into additional_information for cross-stage ID consistency + # This allows workers (like GPUARModelRunner) to use the global ID for side-channel + # operations like KV transfer, even if they use internal IDs for execution. + if "request_id" in payload and "engine_inputs" in payload: + req_id = payload["request_id"] + ein = payload["engine_inputs"] + + # Helper to inject into additional_information + def _inject_global_id(target_ein): + # OmniTokensPrompt is a TypedDict at runtime, so we treat it as a dict + if isinstance(target_ein, dict): + if "additional_information" not in target_ein: + target_ein["additional_information"] = {} + + # Ensure additional_information is a dict before assignment + # (in case it was somehow initialized as None or other type) + if target_ein["additional_information"] is None: + target_ein["additional_information"] = {} + + if isinstance(target_ein["additional_information"], dict): + # Wrap in list because OmniInputProcessor requires Tensor or list values + target_ein["additional_information"]["global_request_id"] = [str(req_id)] + + if isinstance(ein, list): + for item in ein: + _inject_global_id(item) + else: + _inject_global_id(ein) + self._in_q.put(payload) def try_collect(self) -> dict[str, Any] | None: @@ -399,10 +611,27 @@ def try_collect(self) -> dict[str, Any] | None: request_id, engine_outputs (or engine_outputs_shm), and metrics. """ assert self._out_q is not None + # Ensure transformers_modules (trust_remote_code cache) is importable + # in this process before pickle deserialization of Stage-0 output. + import os as _os, sys as _sys + _hf_modules = _os.path.join( + _os.environ.get("HF_HOME", _os.path.join(_os.path.expanduser("~"), ".cache", "huggingface")), + "modules" + ) + if _hf_modules not in _sys.path: + _sys.path.insert(0, _hf_modules) try: return self._out_q.get_nowait() - except Exception: + except queue.Empty: return None + except Exception as _e: + import logging as _lg + _lg.getLogger(__name__).error("[Stage-%s] try_collect deser error: %s", self.stage_id, _e) + # Message was consumed but deserialization failed (e.g. transformers_modules not loaded). + # Return minimal stage_ready so the orchestrator marks this stage as ready + # and triggers the engine_args fallback in _wait_for_stages_ready. + return {"type": "stage_ready", "stage_id": self.stage_id, + "vllm_config": None, "tokenizer": None} def process_engine_inputs( self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None @@ -458,23 +687,63 @@ def process_engine_inputs( def _stage_worker( model: str, stage_payload: dict[str, Any], - in_q: mp.Queue, - out_q: mp.Queue, + in_q: mp.queues.Queue | ZmqQueue | str, + out_q: mp.queues.Queue | ZmqQueue | str, batch_timeout: int = 10, stage_init_timeout: int = 300, ) -> None: """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" # Use local aliases to avoid conflicts with global imports in worker process logger.info(f"Starting stage worker with model: {model}") + import multiprocessing as _mp import os as _os import time as _time + import zmq + + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / + # GPUARModelRunner) are spawned with a fork-safe method. + # Mooncake / gRPC / RDMA and CUDA/NCCL can deadlock under fork-with-threads. + if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": + _os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.info("[Stage] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") + # Best-effort: also force python mp start method in this stage process. + # This may raise if already set; that's fine. + try: + _mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + stage_id = stage_payload["stage_id"] engine_args = stage_payload.get("engine_args", {}) runtime_cfg = stage_payload.get("runtime", {}) shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) - stage_type = stage_payload.get("stage_type", "llm") + stage_type: Literal["llm", "diffusion"] = stage_payload.get("stage_type", "llm") + + if stage_type != "diffusion": + _resolve_worker_cls(engine_args) + + # Handle non-standard model directory structures (e.g., tokenizer in root, model in subdir) + model = _resolve_model_tokenizer_paths(model, engine_args) + + # Resolve ZMQ queue endpoints if needed + zmq_ctx = None + if isinstance(in_q, str) or isinstance(out_q, str): + zmq_ctx = zmq.Context() + if isinstance(in_q, str): + in_q = create_zmq_queue(zmq_ctx, in_q, zmq.PULL) + if isinstance(out_q, str): + out_q = create_zmq_queue(zmq_ctx, out_q, zmq.PUSH) + # When using ZMQ (cross-node IPC), disable SHM so data is sent inline. + shm_threshold_bytes = sys.maxsize + logger.info( + "[Stage-%s] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)", + stage_id, + ) # Aggregates for running average _agg_total_tokens = 0 @@ -485,184 +754,53 @@ def _stage_worker( # Device mapping device_type = None try: - device_type = detect_device_type() + from vllm_omni.platforms import current_omni_platform + + device_type = current_omni_platform.device_type set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: logger.warning("Device setup failed: %s", e) - # Sequential initialization on the same device to avoid memory calculation errors - # when multiple instances start simultaneously - # For TP/PP/DP/SP, we need to lock ALL devices that will be used by this stage - lock_files = [] - if device_type == "cuda": - try: - import torch - - if torch.cuda.is_available(): - # Get all parallel sizes from engine_args or parallel_config (defaults to 1) - if "parallel_config" in engine_args: - parallel_config = engine_args["parallel_config"] - tensor_parallel_size = parallel_config.get("tensor_parallel_size", 1) - pipeline_parallel_size = parallel_config.get("pipeline_parallel_size", 1) - data_parallel_size = parallel_config.get("data_parallel_size", 1) - prefill_context_parallel_size = 1 # not used for diffusion - sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1) - cfg_parallel_size = parallel_config.get("cfg_parallel_size", 1) - else: - tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) - pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) - data_parallel_size = engine_args.get("data_parallel_size", 1) - prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) - sequence_parallel_size = 1 # not use in omni model - cfg_parallel_size = 1 # not used in omni model - - # Calculate total number of devices needed for this stage - # For a single stage worker: - # - TP: splits model across GPUs (always needed) - # - PP: splits layers across pipelinestages, but each stage uses TP devices - # - DP: replicates model, but each replica uses TP devices - # - PCP: context parallelism, typically uses TP devices - # - SP: sequence parallelism, typically uses TP devices - # - CFG: Classifier-Free Guidance parallelism for diffusion models - # The number of devices per stage is determined by TP * PP * DP * PCP * SP * CFG size - # (PP/DP/PCP are higher-level parallelism that don't add devices per stage) - num_devices_per_stage = ( - tensor_parallel_size - * pipeline_parallel_size - * data_parallel_size - * prefill_context_parallel_size - * sequence_parallel_size - * cfg_parallel_size - ) - - # Get physical device IDs from CUDA_VISIBLE_DEVICES - # After set_stage_devices, CUDA_VISIBLE_DEVICES is set to physical device(s) - cuda_visible_devices = _os.environ.get("CUDA_VISIBLE_DEVICES") - physical_devices = [] - - if cuda_visible_devices: - try: - physical_devices = [int(x.strip()) for x in cuda_visible_devices.split(",") if x.strip()] - except (ValueError, IndexError): - pass - - if not physical_devices: - # Fallback: use logical device count if CUDA_VISIBLE_DEVICES not set - num_devices = torch.cuda.device_count() - physical_devices = list(range(num_devices)) - - # Determine which devices will be used (min of devices per stage and available devices) - num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) - devices_to_lock = physical_devices[:num_devices_to_lock] - - # Sort devices_to_lock to prevent deadlock (all processes acquire locks in same order) - devices_to_lock = sorted(devices_to_lock) - - logger.debug( - "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d, SP=%d; will lock %d devices: %s", - tensor_parallel_size, - pipeline_parallel_size, - data_parallel_size, - prefill_context_parallel_size, - sequence_parallel_size, - num_devices_to_lock, - devices_to_lock, - ) - - # Acquire exclusive locks for all devices using fcntl.flock - # Locks are automatically released when process dies - wait_start = _time.time() - acquired_lock_fds = [] # Store file descriptors to keep locks alive - - for device_id in devices_to_lock: - lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" - lock_acquired = False - - while not lock_acquired: - try: - # Open or create the lock file - lock_fd = _os.open(lock_file, _os.O_CREAT | _os.O_RDWR, 0o644) - - # Try to acquire exclusive lock (non-blocking first) - try: - fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - # Successfully acquired lock - write PID - _os.ftruncate(lock_fd, 0) # Clear file - _os.write(lock_fd, f"{_os.getpid()}\n".encode()) - _os.fsync(lock_fd) # Ensure written to disk - lock_acquired = True - acquired_lock_fds.append(lock_fd) - logger.debug("Acquired exclusive lock for device %s", device_id) - except BlockingIOError: - # Lock is held by another process - _os.close(lock_fd) - - # Check if we've been waiting too long - if _time.time() - wait_start > stage_init_timeout: - logger.warning( - "Timeout waiting for device %s initialization lock, proceeding anyway", - device_id, - ) - break - - # Wait a bit before retrying - _time.sleep(0.1) - except OSError as e: - # Other error - log and continue without lock - logger.debug( - "Failed to acquire lock for device %s: %s, continuing anyway", - device_id, - e, - ) - try: - _os.close(lock_fd) - except (OSError, NameError): - pass - break - - lock_files = acquired_lock_fds - except Exception as e: - logger.debug( - "[Stage-%s] Failed to set up sequential initialization lock: %s", - stage_id, - e, - ) - # Init engine based on stage_type - logger.debug( - "[Stage-%s] Initializing %s engine with args keys=%s", - stage_id, - stage_type, - list(engine_args.keys()), - ) - try: + # Use sequential init locks only when NVML is unavailable + with _sequential_init_lock(engine_args, stage_init_timeout): + # Init engine based on stage_type + logger.debug( + "[Stage-%s] Initializing %s engine with args keys=%s", stage_id, stage_type, list(engine_args.keys()) + ) + if engine_args.get("async_chunk", False): + logger.debug("[Stage-%s] Async chunk enabled, injecting connectors config", stage_id) + stage_connector_spec = {} + for v in connectors_config.values(): + stage_connector_spec = dict(v.get("spec", {})) + break + engine_args["stage_connector_spec"] = stage_connector_spec + engine_args["stage_id"] = stage_id if stage_type == "diffusion": - engine_args.pop("model_stage") - stage_engine = OmniDiffusion(**engine_args) + engine_args = filter_dataclass_kwargs(OmniDiffusionConfig, engine_args) + engine_args.pop("model_stage", None) + engine_args.pop("model", None) + stage_engine = OmniDiffusion( + model=model, + stage_id=stage_id, + engine_input_source=stage_payload.get("engine_input_source", []), + **engine_args, + ) else: + engine_args = filter_dataclass_kwargs(OmniEngineArgs, engine_args) + engine_args.pop("model", None) # Default to LLM engine stage_engine = OmniLLM(model=model, **engine_args) - finally: - # Release all locks by closing file descriptors - # Locks are automatically released when file descriptors are closed - # or when process dies - for lock_fd in lock_files: - try: - fcntl.flock(lock_fd, fcntl.LOCK_UN) - _os.close(lock_fd) - logger.debug("Released initialization lock (fd=%s)", lock_fd) - except (OSError, ValueError): - pass + logger.debug("Engine initialized") # Initialize OmniConnectors if configured - connectors = {} + connectors: dict[tuple[str, str], OmniConnectorBase] | None = {} if connectors_config: - built_connectors = build_stage_connectors( + connectors = build_stage_connectors( stage_id=stage_id, connectors_config=connectors_config, ) - if built_connectors is None: + if connectors is None: return - connectors = built_connectors # Signal readiness to orchestrator try: @@ -731,6 +869,7 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: continue batch_tasks: list[dict[str, Any]] = [task] + tasks_failed_to_add_to_batch: list[dict[str, Any]] = [] start_time = _time.time() if max_batch_size > 1: while len(batch_tasks) < max_batch_size: @@ -746,7 +885,20 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: if extra_type == OmniStageTaskType.PROFILER_STOP: out_q.put({"type": "profiler_result", "data": p_data}) continue - batch_tasks.append(extra) + # Ensure that all tasks have the same sampling params + # If no, put them in a temporary container and add back to queue + # This should be always true, because user only calls omni.generate() once and it blocks + # User can only pass one sampling param object, but the list of prompts are separated. + if task.get("sampling_params") != extra.get("sampling_params"): + logger.warning( + """In offline mode, expect all prompts in one `omni.generate()` call to share same sampling params""" # noqa: E501 # line too long + f"""However, prompt {task.get("engine_inputs")} has sampling params {task.get("sampling_params")}, """ # noqa: E501 # line too long + f"""whereas the prompt {extra.get("engine_inputs")} has sampling params {extra.get("sampling_params")}.""" # noqa: E501 # line too long + """The two tasks cannot be combined in one batch request.""" + ) + tasks_failed_to_add_to_batch.append(extra) + else: + batch_tasks.append(extra) end_time = _time.time() duration = end_time - start_time if duration > batch_timeout: @@ -761,9 +913,13 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: break else: continue + for task_to_readd in tasks_failed_to_add_to_batch: + in_q.put(task_to_readd) + # Ensure that the popped tasks are with identical sampling params. Take one of them. + batch_engine_sampling_params: OmniSamplingParams = batch_tasks[0]["sampling_params"] batch_request_ids: list[Any] = [] - batch_engine_inputs: list[Any] = [] + batch_engine_inputs: list[OmniPromptType] = [] _rx_bytes_by_rid: dict[Any, int] = {} _rx_decode_ms_by_rid: dict[Any, float] = {} _in_flight_ms_by_rid: dict[Any, float] = {} @@ -772,7 +928,7 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: try: sent_ts = float(t.get("sent_ts", None)) if isinstance(t, dict) else None if sent_ts is not None: - _in_flight_ms_by_rid[rid] = (_recv_dequeue_ts - sent_ts) * 1000.0 + _in_flight_ms_by_rid[rid] = max(0.0, (_recv_dequeue_ts - sent_ts) * 1000.0) else: _in_flight_ms_by_rid[rid] = 0.0 except Exception: @@ -785,6 +941,9 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: connectors=connectors, stage_id=stage_id, ) + # TODO: hack type annotation for now. + # A better way is to refine type annotation of connection and task/payloads, maybe using template types. + ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) if ein is None or _rx_metrics is None: raise RuntimeError( @@ -796,17 +955,15 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) batch_request_ids.append(rid) - if isinstance(ein, list): - batch_engine_inputs.extend(ein) - elif isinstance(ein, dict): - batch_engine_inputs.append(ein) - elif isinstance(ein, str): + + if isinstance(ein, (dict, str)): # For diffusion stage-0, ein might be a string prompt directly batch_engine_inputs.append(ein) + elif isinstance(ein, Sequence): + batch_engine_inputs.extend(ein) else: - # For other types (e.g., OmniTokensPrompt, TextPrompt), append as-is + # Other unknown types, append as-is batch_engine_inputs.append(ein) - sampling_params = batch_tasks[0]["sampling_params"] logger.debug( "Received batch size=%d, request_ids=%s", len(batch_tasks), @@ -814,61 +971,30 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: ) try: _batch_seq += 1 - gen_outputs: list[Any] = [] + gen_outputs: list[OmniRequestOutput | RequestOutput] = [] _gen_t0 = _time.time() if stage_type == "diffusion": - # For diffusion, batch_engine_inputs should be prompts (strings) - # Convert to list of strings if needed - prompts = [] - for ein in batch_engine_inputs: - if isinstance(ein, str): - prompts.append(ein) - elif isinstance(ein, dict) and "prompt" in ein: - prompts.append(ein["prompt"]) - elif hasattr(ein, "prompt"): - prompts.append(ein.prompt) - else: - prompts.append(str(ein)) - # Prepare diffusion kwargs from sampling parameters - diffusion_kwargs = prepare_sampling_params(sampling_params, "diffusion") + stage_engine = cast(OmniDiffusion, stage_engine) + batch_engine_sampling_params = cast(OmniDiffusionSamplingParams, batch_engine_sampling_params) # Diffusion generate returns results directly, not an iterator - diffusion_results = stage_engine.generate(prompts, **diffusion_kwargs) - # Convert to list format compatible with LLM outputs - # Ensure each result has a request_id for proper mapping - if isinstance(diffusion_results, list): - gen_outputs = diffusion_results - # Assign request_ids if not present - for idx, result in enumerate(gen_outputs): - if not hasattr(result, "request_id") or result.request_id is None: - if idx < len(batch_request_ids): - if hasattr(result, "request_id"): - result.request_id = batch_request_ids[idx] - else: - # Create a wrapper object if result doesn't support request_id - from types import SimpleNamespace - - wrapped = SimpleNamespace() - wrapped.request_id = batch_request_ids[idx] - wrapped.output = result - gen_outputs[idx] = wrapped - else: - gen_outputs = [diffusion_results] - # Assign request_id to single result - if len(batch_request_ids) > 0: - if hasattr(gen_outputs[0], "request_id"): - gen_outputs[0].request_id = batch_request_ids[0] - else: - from types import SimpleNamespace - - wrapped = SimpleNamespace() - wrapped.request_id = batch_request_ids[0] - wrapped.output = gen_outputs[0] - gen_outputs[0] = wrapped + diffusion_results = stage_engine.generate( + batch_engine_inputs, batch_engine_sampling_params, batch_request_ids + ) + gen_outputs.extend(diffusion_results) + # Assign request_ids if not present + for idx, result in enumerate(gen_outputs): + if not hasattr(result, "request_id") or result.request_id is None: + if idx < len(batch_request_ids): + result.request_id = batch_request_ids[idx] else: - # LLM engine: use vLLM native SamplingParams - llm_sampling_params = prepare_sampling_params(sampling_params, "llm") - for ro in stage_engine.generate(batch_engine_inputs, llm_sampling_params, use_tqdm=False): - gen_outputs.append(ro) + stage_engine = cast(OmniLLM, stage_engine) + batch_engine_sampling_params = cast(SamplingParams, batch_engine_sampling_params) + results = stage_engine.generate( + batch_engine_inputs, # type: ignore # silent complaints about list of subclassed TypedDict + batch_engine_sampling_params, + use_tqdm=False, + ) + gen_outputs.extend(results) _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 logger.debug(f"Generate done: batch={len(batch_tasks)}, req_ids={batch_request_ids}, gen_ms={_gen_ms:.1f}") @@ -877,7 +1003,7 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: req_to_outputs: dict[Any, list[Any]] = {rid: [] for rid in batch_request_ids} unmapped: list[Any] = [] for ro in gen_outputs: - rid = getattr(ro, "request_id", None) + rid = ro.request_id if rid in req_to_outputs: req_to_outputs[rid].append(ro) else: @@ -957,36 +1083,74 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: def _stage_worker_async_entry( - omni_stage: OmniStage, model: str, stage_payload: dict[str, Any], + in_q: mp.queues.Queue | ZmqQueue | str, + out_q: mp.queues.Queue | ZmqQueue | str, batch_timeout: int = 10, stage_init_timeout: int = 300, ) -> None: - asyncio.run(_stage_worker_async(omni_stage, model, stage_payload, batch_timeout, stage_init_timeout)) + asyncio.run(_stage_worker_async(model, stage_payload, in_q, out_q, batch_timeout, stage_init_timeout)) async def _stage_worker_async( - omni_stage: OmniStage, model: str, stage_payload: dict[str, Any], + in_q: mp.queues.Queue | ZmqQueue | str, + out_q: mp.queues.Queue | ZmqQueue | str, batch_timeout: int = 10, stage_init_timeout: int = 300, ) -> None: """Stage worker entry: device setup, LLM init, batching, SHM IPC.""" # Use local aliases to avoid conflicts with global imports in worker process + import multiprocessing as _mp import os as _os import time as _time + import zmq + + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + # IMPORTANT: Ensure vLLM's internal multiprocessing workers (e.g., GPUARWorker / + # GPUARModelRunner) are spawned with a fork-safe method. + if _os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn": + _os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.info("[Stage-async] Set VLLM_WORKER_MULTIPROC_METHOD=spawn") + try: + _mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + stage_id = stage_payload["stage_id"] engine_args = stage_payload.get("engine_args", {}) runtime_cfg = stage_payload.get("runtime", {}) shm_threshold_bytes = int(stage_payload.get("shm_threshold_bytes", 65536)) connectors_config = stage_payload.get("connectors_config", {}) stage_type = stage_payload.get("stage_type", "llm") + final_output = stage_payload.get("final_output", False) + final_output_type = stage_payload.get("final_output_type", None) - in_q = omni_stage._in_q - out_q = omni_stage._out_q + # Handle non-standard model directory structures (e.g., tokenizer in root, model in subdir) + model = _resolve_model_tokenizer_paths(model, engine_args) + + if stage_type != "diffusion": + _resolve_worker_cls(engine_args) + + # Resolve ZMQ queue endpoints if needed + zmq_ctx = None + if isinstance(in_q, str) or isinstance(out_q, str): + zmq_ctx = zmq.Context() + if isinstance(in_q, str): + in_q = create_zmq_queue(zmq_ctx, in_q, zmq.PULL) + if isinstance(out_q, str): + out_q = create_zmq_queue(zmq_ctx, out_q, zmq.PUSH) + # When using ZMQ (cross-node IPC), disable SHM so data is sent inline. + shm_threshold_bytes = sys.maxsize + logger.info( + "[Stage-%s] ZMQ transport detected; disabling SHM IPC (shm_threshold_bytes set to maxsize)", + stage_id, + ) # Aggregates for running average _agg_total_tokens = 0 @@ -998,9 +1162,9 @@ async def _stage_worker_async( # Device mapping device_type = None try: - from vllm_omni.utils import detect_device_type + from vllm_omni.platforms import current_omni_platform - device_type = detect_device_type() + device_type = current_omni_platform.device_type set_stage_devices(stage_id, runtime_cfg.get("devices"), device_type=device_type) except Exception as e: logger.warning("Device setup failed: %s", e) @@ -1016,134 +1180,34 @@ async def _stage_worker_async( return connectors = built_connectors - # Sequential initialization on the same device to avoid memory calculation errors - # when multiple instances start simultaneously - # For TP, we need to lock ALL devices that will be used by this stage - lock_files = [] - if device_type == "cuda": - try: - import torch - - if torch.cuda.is_available(): - # Get all parallel sizes from engine_args (defaults to 1) - tensor_parallel_size = engine_args.get("tensor_parallel_size", 1) - pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1) - data_parallel_size = engine_args.get("data_parallel_size", 1) - prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1) - - # Calculate total number of devices needed for this stage - # For a single stage worker in omni: - # - TP: splits model across GPUs (always needed) - # - PP: splits layers across stages, but each stage uses TP devices - # - DP: replicates model, but each replica uses TP devices - # - PCP: context parallelism, typically uses TP devices - # The number of devices per stage is determined by TP * PP * DP * PCP size - # (PP/DP/PCP are higher-level parallelism that don't add devices per stage) - num_devices_per_stage = ( - tensor_parallel_size * pipeline_parallel_size * data_parallel_size * prefill_context_parallel_size - ) - - # Get physical device IDs from CUDA_VISIBLE_DEVICES - # After set_stage_devices, CUDA_VISIBLE_DEVICES is set to physical device(s) - cuda_visible_devices = _os.environ.get("CUDA_VISIBLE_DEVICES") - physical_devices = [] - - if cuda_visible_devices: - try: - physical_devices = [int(x.strip()) for x in cuda_visible_devices.split(",") if x.strip()] - except (ValueError, IndexError): - pass - - if not physical_devices: - # Fallback: use logical device count if CUDA_VISIBLE_DEVICES not set - num_devices = torch.cuda.device_count() - physical_devices = list(range(num_devices)) - - # Determine which devices will be used (min of devices per stage and available devices) - num_devices_to_lock = min(num_devices_per_stage, len(physical_devices)) - devices_to_lock = physical_devices[:num_devices_to_lock] - - # Sort devices_to_lock to prevent deadlock (all processes acquire locks in same order) - devices_to_lock = sorted(devices_to_lock) - - logger.debug( - "Parallel config: TP=%d, PP=%d, DP=%d, PCP=%d; will lock %d devices: %s", - tensor_parallel_size, - pipeline_parallel_size, - data_parallel_size, - prefill_context_parallel_size, - num_devices_to_lock, - devices_to_lock, - ) - - # Acquire exclusive locks for all devices using fcntl.flock - # Locks are automatically released when process dies - wait_start = _time.time() - acquired_lock_fds = [] # Store file descriptors to keep locks alive - - for device_id in devices_to_lock: - lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock" - lock_acquired = False - - while not lock_acquired: - try: - # Open or create the lock file - lock_fd = _os.open(lock_file, _os.O_CREAT | _os.O_RDWR, 0o644) - - # Try to acquire exclusive lock (non-blocking first) - try: - fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) - # Successfully acquired lock - write PID - _os.ftruncate(lock_fd, 0) # Clear file - _os.write(lock_fd, f"{_os.getpid()}\n".encode()) - _os.fsync(lock_fd) # Ensure written to disk - lock_acquired = True - acquired_lock_fds.append(lock_fd) - logger.debug("Acquired exclusive lock for device %s", device_id) - except BlockingIOError: - # Lock is held by another process - _os.close(lock_fd) - - # Check if we've been waiting too long - if _time.time() - wait_start > stage_init_timeout: - logger.warning( - "Timeout waiting for device %s initialization lock, " - "proceeding anyway with timeout %s", - device_id, - stage_init_timeout, - ) - break - - # Wait a bit before retrying - _time.sleep(0.1) - except OSError as e: - # Other error - log and continue without lock - logger.debug( - "Failed to acquire lock for device %s: %s, continuing anyway", - device_id, - e, - ) - try: - _os.close(lock_fd) - except (OSError, NameError): - pass - break - - lock_files = acquired_lock_fds - except Exception as e: - logger.debug("Failed to set up sequential initialization lock: %s", e) - - # Init engine based on stage_type - logger.debug( - "[Stage-%s] Initializing %s engine with args keys=%s", - stage_id, - stage_type, - list(engine_args.keys()), - ) - try: + # Use sequential init locks only when NVML is unavailable + with _sequential_init_lock(engine_args, stage_init_timeout): + # Init engine based on stage_type + logger.debug( + "[Stage-%s] Initializing %s engine with args keys=%s", + stage_id, + stage_type, + list(engine_args.keys()), + ) + if engine_args.get("async_chunk", False): + logger.debug("[Stage-%s] Async chunk enabled, injecting connectors config", stage_id) + stage_connector_spec = {} + for v in connectors_config.values(): + stage_connector_spec = dict(v.get("spec", {})) + break + engine_args["stage_connector_spec"] = stage_connector_spec + engine_args["stage_id"] = stage_id if stage_type == "diffusion": # For diffusion, we need to extract diffusion-specific config + engine_args = filter_dataclass_kwargs(OmniDiffusionConfig, engine_args) od_config = _build_od_config(engine_args, model) + + # Inject omni config for worker to access stage info + if "omni_kv_config" not in od_config: + od_config["omni_kv_config"] = {} + od_config["omni_kv_config"]["stage_id"] = stage_id + od_config["omni_kv_config"]["engine_input_source"] = stage_payload.get("engine_input_source", []) + logger.debug(f"[Stage-%s] Initializing diffusion engine with config: {od_config}", stage_id) stage_engine = AsyncOmniDiffusion( model=model, @@ -1152,6 +1216,8 @@ async def _stage_worker_async( ) vllm_config = None # Diffusion doesn't use vllm_config else: + engine_args = filter_dataclass_kwargs(AsyncOmniEngineArgs, engine_args) + engine_args.pop("model", None) omni_engine_args = AsyncOmniEngineArgs(model=model, **engine_args) usage_context = UsageContext.OPENAI_API_SERVER vllm_config = omni_engine_args.create_engine_config(usage_context=usage_context) @@ -1159,26 +1225,17 @@ async def _stage_worker_async( vllm_config=vllm_config, usage_context=usage_context, engine_args=omni_engine_args, + disable_log_stats=bool( + engine_args.get("disable_log_stats", True) or getattr(omni_engine_args, "disable_log_stats", True) + ), ) - finally: - # Release all locks by closing file descriptors - # Locks are automatically released when file descriptors are closed - # or when process dies - for lock_fd in lock_files: - try: - fcntl.flock(lock_fd, fcntl.LOCK_UN) - _os.close(lock_fd) - logger.debug("Released initialization lock (fd=%s)", lock_fd) - except (OSError, ValueError): - pass - omni_stage.set_async_engine(stage_engine) - if hasattr(omni_stage.async_engine, "log_stats") and omni_stage.async_engine.log_stats: + if hasattr(stage_engine, "log_stats") and stage_engine.log_stats: async def _force_log(): try: while True: await asyncio.sleep(10.0) - await omni_stage.async_engine.do_log_stats() + await stage_engine.do_log_stats() except asyncio.CancelledError: pass @@ -1191,16 +1248,15 @@ async def _force_log(): await stage_engine.reset_mm_cache() logger.debug("[Stage-%s] Engine initialized", stage_id) - async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: + async def handle_profiler_task_async(task_type: OmniStageTaskType) -> dict: """Handle profiler task asynchronously for both LLM and diffusion stages.""" if task_type == OmniStageTaskType.PROFILER_START: if stage_type == "diffusion": try: - # Sync call is safe here — diffusion profiling is lightweight profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles") os.makedirs(profile_dir, exist_ok=True) trace_filename = f"stage_{stage_id}_diffusion_{int(time.time())}" - stage_engine.start_profile(trace_filename=trace_filename) + await stage_engine.start_profile(trace_filename=trace_filename) logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id) except Exception as e: logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e) @@ -1210,14 +1266,17 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: logger.info("[Stage-%s] vLLM profiler started", stage_id) except Exception as e: logger.warning("[Stage-%s] Failed to start vLLM profiler: %s", stage_id, e) + return {} elif task_type == OmniStageTaskType.PROFILER_STOP: + result_data: dict = {} if stage_type == "diffusion": try: - trace_files = stage_engine.stop_profile() + trace_files = await stage_engine.stop_profile() logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id) if trace_files: logger.info("Diffusion trace files: %s", trace_files) + result_data = trace_files except Exception as e: logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e) else: @@ -1226,6 +1285,8 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: logger.info("[Stage-%s] vLLM profiler stopped", stage_id) except Exception as e: logger.warning("[Stage-%s] Failed to stop vLLM profiler: %s", stage_id, e) + return result_data + return {} # Signal readiness to orchestrator and send vllm_config back to main process try: @@ -1243,6 +1304,14 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None: # Only add is_tracing_enabled for LLM engines if stage_type != "diffusion": stage_ready_payload["is_tracing_enabled"] = await stage_engine.is_tracing_enabled() + import pickle + try: + pickle.loads(pickle.dumps(stage_ready_payload)) + except Exception: + logger.warning("[Stage-%s] stage_ready_payload not picklable, dropping vllm_config/tokenizer", stage_id) + stage_ready_payload = {"type": "stage_ready", "stage_id": stage_id, + "is_tracing_enabled": stage_ready_payload.get("is_tracing_enabled"), + "vllm_config": None, "tokenizer": None} out_q.put(stage_ready_payload) except Exception as e: logger.warning("Failed to send stage ready signal: %s", e) @@ -1259,7 +1328,7 @@ async def generation_single_request(task: dict[str, Any]): try: sent_ts = float(task.get("sent_ts", None)) if isinstance(task, dict) else None if sent_ts is not None: - _in_flight_ms_by_rid[rid] = (_recv_dequeue_ts - sent_ts) * 1000.0 + _in_flight_ms_by_rid[rid] = max(0.0, (_recv_dequeue_ts - sent_ts) * 1000.0) else: _in_flight_ms_by_rid[rid] = 0.0 except Exception: @@ -1270,6 +1339,10 @@ async def generation_single_request(task: dict[str, Any]): connectors=connectors, stage_id=stage_id, ) + # TODO: hack type annotation for now. + # A better way is to refine type annotation of connection and task/payloads, maybe using template types. + ein = cast(OmniPromptType | Sequence[OmniPromptType] | None, ein) + if ein is None or _rx_metrics is None: raise RuntimeError( f"[Stage-{stage_id}] Missing connector payload for request {rid}. " @@ -1278,36 +1351,27 @@ async def generation_single_request(task: dict[str, Any]): _rx_decode_ms_by_rid[rid] = float(_rx_metrics.get("rx_decode_time_ms", 0.0)) _rx_bytes_by_rid[rid] = int(_rx_metrics.get("rx_transfer_bytes", 0)) - sampling_params = task["sampling_params"] logger.debug("Received batch size=1, request_ids=%s", rid) _gen_t0 = _time.time() - if isinstance(ein, list): + if isinstance(ein, Sequence) and not isinstance(ein, str): + if len(ein) == 0: + logger.info("[Stage-%s] Skipping request %s: no engine inputs", stage_id, rid) + out_q.put({"request_id": rid, "stage_id": stage_id, "skipped": True}) + return ein = ein[0] if stage_type == "diffusion": - # For diffusion, ein should be prompts (strings) - # Convert to string if needed - if isinstance(ein, str): - prompt = ein - elif isinstance(ein, dict) and "prompt" in ein: - prompt = ein["prompt"] - elif hasattr(ein, "prompt"): - prompt = ein.prompt - else: - prompt = str(ein) - - # Prepare diffusion kwargs from sampling parameters - diffusion_kwargs = prepare_sampling_params(sampling_params, "diffusion") + diffusion_sampling_params = cast(OmniDiffusionSamplingParams, task["sampling_params"]) # AsyncOmniDiffusion.generate returns a single result, not an async generator - gen_output = await stage_engine.generate(prompt=prompt, request_id=rid, **diffusion_kwargs) + gen_output = await cast(AsyncOmniDiffusion, stage_engine).generate(ein, diffusion_sampling_params, rid) _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 await generation_out_q.put((rid, gen_output, _gen_ms)) else: - # LLM stages: ensure using SamplingParams - llm_sampling_params = prepare_sampling_params(sampling_params, "llm") + ein = cast(PromptType, ein) + llm_sampling_params: SamplingParams = task["sampling_params"] gen_output = None - async for res in stage_engine.generate(ein, llm_sampling_params, rid): + async for res in cast(AsyncLLM, stage_engine).generate(ein, llm_sampling_params, rid): gen_output = res _gen_t1 = _time.time() _gen_ms = (_gen_t1 - _gen_t0) * 1000.0 @@ -1336,7 +1400,10 @@ async def generation_single_request(task: dict[str, Any]): rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) elif is_profiler_task(task_type): - await handle_profiler_task_async(task_type) + profiler_data = await handle_profiler_task_async(task_type) + # Send result back to orchestrator for STOP command + if task_type == OmniStageTaskType.PROFILER_STOP: + out_q.put({"type": "profiler_result", "data": profiler_data}) else: asyncio.create_task(generation_single_request(task)) @@ -1384,7 +1451,7 @@ async def generation_single_request(task: dict[str, Any]): batch_request_ids, batch_request_outputs, _gen_ms_list, batch_metrics ): try: - r_outputs = [output] + r_outputs = [output_strip(output, final_output, final_output_type)] use_shm, payload = maybe_dump_to_shm(r_outputs, shm_threshold_bytes) if use_shm: out_q.put( @@ -1447,13 +1514,11 @@ def make_request_stats( rx_transfer_bytes: int, rx_in_flight_time_ms: float, ): - from vllm_omni.entrypoints.log_utils import ( - StageRequestMetrics, - ) + from vllm_omni.metrics import StageRequestStats num_tokens_in = count_prompt_tokens_from_outputs(req_output) num_tokens_out = count_tokens_from_outputs(req_output) - return StageRequestMetrics( + return StageRequestStats( num_tokens_in=num_tokens_in, num_tokens_out=num_tokens_out, stage_gen_time_ms=stage_gen_time_ms, @@ -1467,6 +1532,35 @@ def make_request_stats( def make_stage_stats(_agg_total_tokens: int, _agg_total_gen_time_ms: float): - from vllm_omni.entrypoints.log_utils import StageStats + from vllm_omni.metrics import StageStats + + return StageStats(total_token=_agg_total_tokens, total_gen_time_ms=_agg_total_gen_time_ms) + + +def output_strip(r_output: RequestOutput | OmniRequestOutput, final_output: bool, final_output_type: str | None): + """ + Strip unnecessary multimodal outputs from stages results, + in order to: + - reduce memory usage + - reduce transfer & serialization overhead + """ + + # check multimodal data is required by stage output config. + if final_output and final_output_type != "text": + return r_output + + # If the request has already finished, should not be altered. + if getattr(r_output, "finished", False): + return r_output + + mm_output = getattr(r_output, "multimodal_output", None) + if mm_output is not None: + r_output.multimodal_output = {} + + outputs = getattr(r_output, "outputs", None) + if outputs is not None: + for out in outputs: + if getattr(out, "multimodal_output", None): + out.multimodal_output = {} - return StageStats(total_token=_agg_total_tokens, total_gen_time=_agg_total_gen_time_ms) + return r_output diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 6fb9750ccc1..56cc9dd1011 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -3,44 +3,50 @@ import json import time import uuid -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Callable from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import TYPE_CHECKING, Any, Final, Optional +from typing import TYPE_CHECKING, Any, Final, Optional, cast import jinja2 +import torch from fastapi import Request from PIL import Image from pydantic import TypeAdapter +from vllm.renderers import BaseRenderer + +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.openai.protocol.chat_completion import OmniChatCompletionResponse +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt try: import soundfile except ImportError: soundfile = None + from openai.types.chat.chat_completion_audio import ChatCompletionAudio as OpenAIChatCompletionAudio from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, get_history_tool_calls_cnt, make_tool_call_id, - resolve_chat_template_content_format, ) -from vllm.entrypoints.harmony_utils import get_streamable_parser_for_assistant, parse_chat_output -from vllm.entrypoints.openai.protocol import ( +from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, +) +from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat +from vllm.entrypoints.openai.engine.protocol import ( DeltaFunctionCall, DeltaMessage, DeltaToolCall, + ErrorInfo, ErrorResponse, FunctionCall, FunctionDefinition, @@ -49,39 +55,38 @@ ToolCall, UsageInfo, ) -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import ( - ChatLikeRequest, - EngineTokensPrompt, - RequestPrompt, - ResponsesRequest, - TextTokensPrompt, - clamp_prompt_logprobs, - is_list_of, +from vllm.entrypoints.openai.engine.serving import ChatLikeRequest, clamp_prompt_logprobs +from vllm.entrypoints.openai.parser.harmony_utils import ( + get_streamable_parser_for_assistant, + parse_chat_output, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import should_include_usage from vllm.inputs.data import PromptType from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.reasoning import ReasoningParser +from vllm.renderers import merge_kwargs +from vllm.renderers.inputs import TokPrompt from vllm.sampling_params import SamplingParams from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike as AnyTokenizer from vllm.tokenizers.mistral import ( MistralTokenizer, maybe_serialize_tool_calls, truncate_tool_call_ids, validate_request_params, ) -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.utils.collection_utils import as_list -from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id from vllm_omni.outputs import OmniRequestOutput if TYPE_CHECKING: @@ -165,7 +170,21 @@ async def create_chat_completion( model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer() + renderer = self.renderer + tokenizer = renderer.get_tokenizer() + if tokenizer is None: + tokenizer = await self.engine_client.get_tokenizer() + + reasoning_parser: ReasoningParser | None = None + if self.reasoning_parser_cls: + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) + reasoning_parser = self.reasoning_parser_cls( + tokenizer, + chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] + ) tool_parser = self.tool_parser @@ -177,57 +196,91 @@ async def create_chat_completion( truncate_tool_call_ids(request) validate_request_params(request) - if ( - request.tool_choice == "auto" - and not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, MistralTokenizer) - and not self.use_harmony + # Check if tool parsing is unavailable (common condition) + tool_parsing_unavailable = ( + tool_parser is None and not isinstance(tokenizer, MistralTokenizer) and not self.use_harmony + ) + + # Validate tool_choice when tool parsing is required but unavailable + if tool_parsing_unavailable and request.tool_choice not in ( + None, + "none", ): - # for hf tokenizers, "auto" tools requires - # --enable-auto-tool-choice and --tool-call-parser - return self.create_error_response( - '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set' - ) + if request.tool_choice == "auto" and not self.enable_auto_tools: + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set' + ) + elif request.tool_choice != "auto": + # "required" or named tool requires tool parser + return self.create_error_response( + f'tool_choice="{request.tool_choice}" requires --tool-call-parser to be set' + ) if request.tools is None or (request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none): tool_dicts = None else: tool_dicts = [tool.model_dump() for tool in request.tools] - # Common case. - request_chat_template = request.chat_template - chat_template_kwargs = request.chat_template_kwargs - if not self.trust_request_chat_template and ( - request_chat_template is not None - or (chat_template_kwargs and chat_template_kwargs.get("chat_template") is not None) - ): - return self.create_error_response( - "Chat template is passed with request, but --trust-request-chat-template is not set. " - "Refused request with untrusted chat template." + if not self.use_harmony: + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, ) - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( - request, - tokenizer, - request.messages, - chat_template=request_chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tool_dicts=tool_dicts, - documents=request.documents, - chat_template_kwargs=request.chat_template_kwargs, - tool_parser=tool_parser, - add_special_tokens=request.add_special_tokens, - ) + if error_check_ret is not None: + return error_check_ret + + chat_template_kwargs = request.chat_template_kwargs or {} + chat_template_kwargs.update(reasoning_effort=request.reasoning_effort) + + # Merge chat_template_kwargs with defaults + merged_template_kwargs = self._prepare_extra_chat_template_kwargs( + chat_template_kwargs, + self.default_chat_template_kwargs, + ) + conversation, engine_prompts = await self._preprocess_chat( + request, + request.messages, + default_template=request.chat_template or self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=merged_template_kwargs, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + # OMNI: Additional parameters + renderer=renderer, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + documents=getattr(request, "documents", None), + add_special_tokens=request.add_special_tokens, + ) + else: + should_include_tools = tool_dicts is not None + conversation, engine_prompts = self._make_request_with_harmony(request, should_include_tools) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") + # Zero-shot TTS: extract last input_audio bytes as speaker reference for S2S. + _ref_audio_b64 = None + for _msg in request.messages: + _content = getattr(_msg, 'content', None) or (_msg.get('content') if isinstance(_msg, dict) else None) + if isinstance(_content, list): + for _part in _content: + _ptype = (_part.get('type') if isinstance(_part, dict) else getattr(_part, 'type', None)) + if _ptype == 'input_audio': + _ia = (_part.get('input_audio') if isinstance(_part, dict) else getattr(_part, 'input_audio', None)) + if _ia is not None: + _data = (_ia.get('data') if isinstance(_ia, dict) else getattr(_ia, 'data', None)) + if _data: + _ref_audio_b64 = _data + if _ref_audio_b64 is not None: + for _ep in engine_prompts: + if isinstance(_ep, dict): + _ep['ref_audio_b64'] = _ref_audio_b64 + request_id = f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" request_metadata = RequestResponseMetadata(request_id=request_id) @@ -239,6 +292,92 @@ async def create_chat_completion( output_modalities if output_modalities is not None else self.engine_client.output_modalities ) + # Omni multistage image generation: Stage-0 (AR) should receive a clean + # text prompt (and optional conditioning image/size) so the model's own + # processor can construct the correct inputs. + # If we pass pre-tokenized chat-template ids, GLM-Image can become + # effectively unconditioned and produce nonsense images. + # Skip if audio input is present (A2T/S2T needs full chat template prompt). + # Skip if Stage-0 is an LLM (e.g. HCX Omni thinker) - it needs the + # pre-tokenized chat-template prompt, not a bare text prompt. + _has_audio_input = any( + "audio" in (ep.get("multi_modal_data") or {}) + for ep in engine_prompts + ) + _stage0_is_llm = ( + hasattr(self.engine_client, "stage_list") + and self.engine_client.stage_list + and getattr(self.engine_client.stage_list[0], "stage_type", None) == "llm" + ) + if request.modalities and ("image" in request.modalities) and not getattr(request, "continue_final_message", False) and not _has_audio_input and not _stage0_is_llm: + try: + messages_as_dicts: list[dict[str, Any]] = [] + for msg in request.messages: + if hasattr(msg, "model_dump"): + messages_as_dicts.append(msg.model_dump()) + elif isinstance(msg, dict): + messages_as_dicts.append(msg) + else: + messages_as_dicts.append( + { + "role": getattr(msg, "role", "user"), + "content": getattr(msg, "content", ""), + } + ) + extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images(messages_as_dicts) + if not extracted_prompt: + return self.create_error_response("No text prompt found in messages") + + extra_body = getattr(request, "extra_body", None) or {} + height = extra_body.get("height") + width = extra_body.get("width") + if "size" in extra_body: + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except Exception: + pass + negative_prompt = extra_body.get("negative_prompt") + + engine_prompt_image: dict[str, Any] | None = None + if reference_images: + # Best-effort decode first reference image for i2i. + try: + img_bytes = base64.b64decode(reference_images[0]) + img = Image.open(BytesIO(img_bytes)) + engine_prompt_image = {"image": img} + except Exception: + engine_prompt_image = None + + # Override the prompts produced by chat-template preprocessing. + tprompt: OmniTextPrompt = {"prompt": extracted_prompt} + if negative_prompt is not None: + tprompt["negative_prompt"] = negative_prompt + # GLM-Image's _call_hf_processor expects target_h/target_w in mm_processor_kwargs + mm_processor_kwargs: dict[str, Any] = {} + if height is not None: + mm_processor_kwargs["target_h"] = height + if width is not None: + mm_processor_kwargs["target_w"] = width + if mm_processor_kwargs: + tprompt["mm_processor_kwargs"] = mm_processor_kwargs + if engine_prompt_image is not None: + tprompt["multi_modal_data"] = engine_prompt_image + + engine_prompts = [tprompt] + # Store height/width for applying to diffusion stage sampling params later + _image_gen_height = height + _image_gen_width = width + except Exception as e: + logger.warning("Failed to build image-generation prompt for omni multistage: %s", e) + _image_gen_height = None + _image_gen_width = None + else: + _image_gen_height = None + _image_gen_width = None + # Schedule the request and get the result generator. generators: list[AsyncGenerator[RequestOutput, None]] = [] try: @@ -249,29 +388,32 @@ async def create_chat_completion( # Use standard OpenAI API parameters for comprehension stage sampling_params_list = self._build_sampling_params_list_from_request(request) + # Apply user-specified height/width to diffusion stage(s) for image generation + if _image_gen_height is not None or _image_gen_width is not None: + for idx, sp in enumerate(sampling_params_list): + # Diffusion stages typically have height/width attributes + if hasattr(sp, "height") and _image_gen_height is not None: + sp.height = _image_gen_height + if hasattr(sp, "width") and _image_gen_width is not None: + sp.width = _image_gen_width + self._log_inputs( request_id, - request_prompts[i], + engine_prompt, params_list=sampling_params_list, lora_request=lora_request, ) - trace_headers = None if raw_request is None else await self._get_trace_headers(raw_request.headers) - generator = self.engine_client.generate( prompt=engine_prompt, request_id=request_id, sampling_params_list=sampling_params_list, output_modalities=output_modalities, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, ) generators.append(generator) except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) assert len(generators) == 1 (result_generator,) = generators @@ -286,6 +428,7 @@ async def create_chat_completion( conversation, tokenizer, request_metadata, + reasoning_parser, ) try: @@ -297,75 +440,82 @@ async def create_chat_completion( conversation, tokenizer, request_metadata, + reasoning_parser, ) except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) async def _preprocess_chat( self, request: ChatLikeRequest | ResponsesRequest, - tokenizer: TokenizerLike, messages: list[ChatCompletionMessageParam], - chat_template: str | None, - chat_template_content_format: ChatTemplateContentFormatOption, + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None = None, + tool_dicts: list[dict[str, Any]] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + # OMNI: Additional parameters for backward compatibility + renderer: BaseRenderer | None = None, add_generation_prompt: bool = True, continue_final_message: bool = False, - tool_dicts: list[dict[str, Any]] | None = None, documents: list[dict[str, str]] | None = None, - chat_template_kwargs: dict[str, Any] | None = None, - tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, add_special_tokens: bool = False, - ) -> tuple[ - list[ConversationMessage], - Sequence[RequestPrompt], - list[EngineTokensPrompt], - ]: - model_config = self.model_config - - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tool_dicts, - chat_template_content_format, - tokenizer, - model_config=model_config, - ) - conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( - messages, - model_config, - tokenizer, - content_format=resolved_content_format, - mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), + ) -> tuple[list[ConversationMessage], list[TokPrompt]]: + if renderer is None: + renderer = self.renderer + + # Keep OMNI compatibility args wired while delegating rendering + # to the upstream async renderer pipeline. + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + documents=documents, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + add_special_tokens=add_special_tokens, + tokenize=isinstance(renderer.tokenizer, MistralTokenizer), + ), ) - _chat_template_kwargs: dict[str, Any] = dict( - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tool_dicts, - documents=documents, + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, + default_template_content_format, + ).with_defaults(default_template_kwargs) + + (conversation,), (engine_prompt,) = await renderer.render_chat_async( + [messages], + chat_params, + tok_params, + prompt_extras={ + k: v for k in ("mm_processor_kwargs", "cache_salt") if (v := getattr(request, k, None)) is not None + }, ) - _chat_template_kwargs.update(chat_template_kwargs or {}) - - request_prompt: str | list[int] - - if tokenizer is None: - request_prompt = "placeholder" - elif isinstance(tokenizer, MistralTokenizer): - request_prompt = apply_mistral_chat_template( - tokenizer, - messages=messages, - **_chat_template_kwargs, - ) - else: - request_prompt = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - mm_data = await mm_data_future + # OMNI: When use_audio_in_video=True, the upstream renderer does not + # extract audio from video. We do it here after rendering so that the + # audio data is present in multi_modal_data before the engine processes + # the request. + mm_proc_kw = getattr(request, "mm_processor_kwargs", None) or {} + if mm_proc_kw.get("use_audio_in_video", False) and isinstance(engine_prompt, dict): + mm_data = engine_prompt.get("multi_modal_data") + if mm_data is not None and "video" in mm_data and "audio" not in mm_data: + from vllm_omni.entrypoints.chat_utils import extract_audio_from_video_async + + video_urls: list[str] = [] + for msg in messages: + for part in msg.get("content") or []: + if isinstance(part, dict) and part.get("type") == "video_url": + url = part.get("video_url", {}).get("url") + if url: + video_urls.append(url) + + if video_urls: + audios = await asyncio.gather(*(extract_audio_from_video_async(u) for u in video_urls)) + engine_prompt.setdefault("multi_modal_data", {})["audio"] = list(audios) + + tokenizer = renderer.get_tokenizer() # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser @@ -383,41 +533,38 @@ async def _preprocess_chat( request=request ) - if tokenizer is None: - assert isinstance(request_prompt, str), ( - "Prompt has to be a string", - "when the tokenizer is not initialised", - ) - prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) - elif isinstance(request_prompt, str): - prompt_inputs = await self._tokenize_prompt_input_async( - request, - tokenizer, - request_prompt, - add_special_tokens=add_special_tokens, - ) - else: - # For MistralTokenizer - assert is_list_of(request_prompt, int), "Prompt has to be either a string or a list of token ids" - prompt_inputs = TextTokensPrompt( - prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt, - ) - - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) - if mm_data is not None: - engine_prompt["multi_modal_data"] = mm_data - - if mm_uuids is not None: - engine_prompt["multi_modal_uuids"] = mm_uuids + # Preserve a clean text prompt for downstream stages (e.g., GLM-Image diffusion). + # For /v1/chat/completions, `request_prompt` is often the rendered chat template. + # Diffusion models generally want the raw user caption instead. + # Skip if audio is already in mm_data (A2T request needs full chat template). + output_modalities = getattr(self.engine_client, "output_modalities", None) + _has_audio_mm = "audio" in (engine_prompt.get("multi_modal_data") or {}) + if output_modalities and ("image" in output_modalities) and not continue_final_message and not _has_audio_mm: + messages_as_dicts: list[dict[str, Any]] = [] + for msg in messages: + if hasattr(msg, "model_dump"): + messages_as_dicts.append(msg.model_dump()) + elif isinstance(msg, dict): + messages_as_dicts.append(msg) + else: + messages_as_dicts.append( + { + "role": getattr(msg, "role", "user"), + "content": getattr(msg, "content", ""), + } + ) + extracted_prompt, _ = self._extract_diffusion_prompt_and_images(messages_as_dicts) + if extracted_prompt: + engine_prompt["prompt"] = extracted_prompt - if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None) + if mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs if hasattr(request, "cache_salt") and request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return conversation, [request_prompt], [engine_prompt] + return conversation, [engine_prompt] def _to_sampling_params_list(self, sampling_params_list: list[dict]) -> list[SamplingParams]: final_sampling_params_list = [] @@ -443,9 +590,13 @@ def _get_comprehension_stage_index(self) -> int: _OPENAI_SAMPLING_FIELDS: set[str] = { "temperature", "top_p", + "top_k", "max_tokens", + "min_tokens", "seed", + "ignore_eos", "stop", + "stop_token_ids", "frequency_penalty", "presence_penalty", } @@ -513,29 +664,20 @@ def _build_sampling_params_list_from_request( def _log_inputs( self, request_id: str, - inputs: RequestPrompt | PromptType, + inputs: PromptType | TokPrompt, params_list: list[SamplingParams] | None, lora_request: LoRARequest | None, ) -> None: if self.request_logger is None: return - prompt, prompt_token_ids, prompt_embeds = None, None, None - if isinstance(inputs, str): - prompt = inputs - elif isinstance(inputs, list): - prompt_token_ids = inputs - else: - prompt = getattr(inputs, "prompt", None) - prompt_token_ids = getattr(inputs, "prompt_token_ids", None) - - logger.info( - "Received request %s: prompt: %r, params_list: %s, prompt_token_ids: %s, prompt_embeds shape: %s, lora_request: %s.", # noqa: E501 + components = self._extract_prompt_components(inputs) + self.request_logger.log_inputs( request_id, - prompt, - params_list, - prompt_token_ids, - prompt_embeds.shape if prompt_embeds is not None else None, - lora_request, + components.text, + components.token_ids, + components.embeds, + params=params_list, + lora_request=lora_request, ) async def chat_completion_stream_generator( @@ -547,6 +689,7 @@ async def chat_completion_stream_generator( conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, + reasoning_parser: ReasoningParser | None = None, ): created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" @@ -588,7 +731,7 @@ async def chat_completion_stream_generator( # Only one of these will be used, thus previous_texts and # all_previous_token_ids will not be used twice in the same iteration. - if tool_choice_auto or self.reasoning_parser: + if tool_choice_auto or reasoning_parser: # These are only required in "auto" tool choice case all_previous_token_ids = [[]] * num_choices # For reasoning parser and tool call all enabled @@ -596,19 +739,6 @@ async def chat_completion_stream_generator( reasoning_end_arr = [False] * num_choices else: all_previous_token_ids = None - - try: - if self.reasoning_parser: - reasoning_parser = self.reasoning_parser( - tokenizer, - chat_template_kwargs=request.chat_template_kwargs, # type: ignore - ) - except RuntimeError as e: - logger.exception("Error in reasoning parser creation.") - data = self.create_streaming_error_response(str(e)) - yield f"data: {data}\n\n" - yield "data: [DONE]\n\n" - return # Prepare the tool parser if it's needed try: if tool_choice_auto and self.tool_parser: @@ -617,7 +747,7 @@ async def chat_completion_stream_generator( tool_parsers = [None] * num_choices except Exception as e: logger.exception("Error in tool parser creation.") - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" yield "data: [DONE]\n\n" return @@ -625,6 +755,7 @@ async def chat_completion_stream_generator( stream_options = request.stream_options include_usage, include_continuous_usage = should_include_usage(stream_options, self.enable_force_include_usage) + last_metrics: dict[str, Any] | None = None try: async for omni_res in result_generator: final_output_type = omni_res.final_output_type @@ -633,22 +764,25 @@ async def chat_completion_stream_generator( logger.warning(f"final output type: {final_output_type} is not needed by the request") continue + if omni_res.metrics: + last_metrics = omni_res.metrics + if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) if res.encoder_prompt_token_ids is not None: num_prompt_tokens += len(res.encoder_prompt_token_ids) + # Initialize role before conditional blocks to avoid UnboundLocalError + # when handling audio/image responses + role = self.get_chat_request_role(request) + # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). if first_iteration_dict[final_output_type] and final_output_type == "text": num_cached_tokens = res.num_cached_tokens - # Send first response for each request.n (index) with - # the role - role = self.get_chat_request_role(request) - - # NOTE num_choices defaults to 1 so this usually executes - # once per request + # Send first response for each choice with role + # NOTE: num_choices defaults to 1 so this usually executes once per request for i in range(num_choices): choice_data = ChatCompletionResponseStreamChoice( index=i, @@ -760,7 +894,7 @@ async def chat_completion_stream_generator( delta_message: DeltaMessage | None # just update previous_texts and previous_token_ids - if tool_choice_auto or self.reasoning_parser: + if tool_choice_auto or reasoning_parser: assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -827,7 +961,7 @@ async def chat_completion_stream_generator( # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: if ( - self.reasoning_parser + reasoning_parser and not reasoning_end_arr[i] and not reasoning_parser.is_reasoning_end(previous_token_ids) ): @@ -857,7 +991,7 @@ async def chat_completion_stream_generator( current_text = "" else: # Just to add remaining `content` - if self.reasoning_parser: + if reasoning_parser: delta_text = previous_text + delta_text current_text = "" @@ -893,14 +1027,14 @@ async def chat_completion_stream_generator( output_token_ids = as_list(output.token_ids) if ( - self.reasoning_parser is not None + reasoning_parser is not None and not reasoning_end_arr[i] and res.prompt_token_ids and reasoning_parser.is_reasoning_end(res.prompt_token_ids) ): reasoning_end_arr[i] = True - if self.reasoning_parser and not reasoning_end_arr[i]: + if reasoning_parser and not reasoning_end_arr[i]: delta_message = reasoning_parser.extract_reasoning_streaming( previous_text, current_text, @@ -939,7 +1073,7 @@ async def chat_completion_stream_generator( # handle streaming deltas for tools with "auto" tool choice # and reasoning parser - elif tool_choice_auto and self.reasoning_parser: + elif tool_choice_auto and reasoning_parser: assert tool_parser is not None assert reasoning_parser is not None assert added_content_delta_arr is not None @@ -1020,7 +1154,7 @@ async def chat_completion_stream_generator( tools_streamed[i] = True # when only reasoning - elif self.reasoning_parser: + elif reasoning_parser: delta_message = reasoning_parser.extract_reasoning_streaming( previous_text, current_text, @@ -1034,7 +1168,7 @@ async def chat_completion_stream_generator( delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if (tool_choice_auto or self.reasoning_parser) and not self.use_harmony: + if (tool_choice_auto or reasoning_parser) and not self.use_harmony: assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -1052,10 +1186,9 @@ async def chat_completion_stream_generator( # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: - if output.finish_reason is None: + if output.finish_reason is None and not request.return_token_ids: continue - else: - delta_message = DeltaMessage() + delta_message = DeltaMessage() # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: @@ -1171,6 +1304,7 @@ async def chat_completion_stream_generator( choices=[choice_data], model=model_name, modality=final_output_type, + metrics=omni_res.metrics, ) # handle usage stats if requested & if continuous @@ -1186,22 +1320,67 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" elif final_output_type == "audio": - choices_data = self._create_audio_choice(omni_res, role, request, stream=True) - chunk = OmniChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=choices_data, - model=model_name, - modality=final_output_type, - ) - chunk.usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=0, - total_tokens=num_prompt_tokens, - ) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" + role = self.get_chat_request_role(request) + # Stream audio as PCM chunks (200ms each @ 24kHz). + # BigVGAN decoder is non-causal: full audio is generated + # first, then split for streaming delivery to the client. + import numpy as np + _AUDIO_CHUNK_SAMPLES = 4800 # 200ms @ 24kHz + _final_res = omni_res.request_output + if _final_res is not None and _final_res.outputs: + _audio_data = _final_res.outputs[0].multimodal_output.get("audio") + else: + _audio_data = omni_res.multimodal_output.get("audio") + # Normalize audio to float32 tensor for uniform PCM chunking. + # HyperCLOVAXAudioPipeline returns WAV bytes (with header); parse to strip it. + # Qwen3-Omni returns float tensors directly. + import soundfile as _sf, io as _io + if isinstance(_audio_data, bytes): + _arr, _ = _sf.read(_io.BytesIO(_audio_data)) + if _arr.ndim > 1: + _arr = _arr.mean(axis=1) + _audio_tensor = torch.from_numpy(_arr.astype(np.float32)) + elif isinstance(_audio_data, list): + _audio_tensor = torch.cat(_audio_data, dim=-1).float().detach().cpu() + else: + _audio_tensor = _audio_data.float().detach().cpu() + _audio_tensor = _audio_tensor.flatten() + _chunks = list(torch.split(_audio_tensor, _AUDIO_CHUNK_SAMPLES)) + _stream_outputs = (_final_res.outputs if (_final_res is not None and _final_res.outputs) + else [None]) + for _chunk_idx, _wav_chunk in enumerate(_chunks): + _pcm = (_wav_chunk.numpy() * 32767.0).clip(-32768, 32767).astype(np.int16) + _pcm_b64 = base64.b64encode(_pcm.tobytes()).decode("ascii") + _is_last_chunk = _chunk_idx == len(_chunks) - 1 + _stream_choices = [] + for _so_idx, output in enumerate(_stream_outputs): + _stream_choices.append( + ChatCompletionResponseStreamChoice( + index=output.index if output is not None else _so_idx, + delta=DeltaMessage( + role=role if _chunk_idx == 0 else None, + content=_pcm_b64, + ), + logprobs=None, + finish_reason="stop" if _is_last_chunk else None, + stop_reason=(output.stop_reason if (output is not None and _is_last_chunk) else None), + ) + ) + _audio_chunk_resp = OmniChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=_stream_choices, + model=model_name, + modality="audio", + ) + if _is_last_chunk: + _audio_chunk_resp.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens, + ) + yield f"data: {_audio_chunk_resp.model_dump_json(exclude_unset=True)}\n\n" else: logger.warning(f"Unsupported streaming final output type: {final_output_type}") @@ -1219,13 +1398,14 @@ async def chat_completion_stream_generator( if self.enable_prompt_tokens_details and num_cached_tokens: final_usage.prompt_tokens_details = PromptTokenUsageInfo(cached_tokens=num_cached_tokens) - final_usage_chunk = ChatCompletionStreamResponse( + final_usage_chunk = OmniChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[], model=model_name, usage=final_usage, + metrics=last_metrics, ) final_usage_data = final_usage_chunk.model_dump_json(exclude_unset=True, exclude_none=True) yield f"data: {final_usage_data}\n\n" @@ -1257,9 +1437,8 @@ async def chat_completion_stream_generator( ) except Exception as e: - # TODO: Use a vllm-specific Validation Error logger.exception("Error in chat completion stream generator.") - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" @@ -1273,7 +1452,8 @@ async def chat_completion_full_generator( conversation: list[ConversationMessage], tokenizer: TokenizerLike, request_metadata: RequestResponseMetadata, - ) -> ErrorResponse | ChatCompletionResponse: + reasoning_parser: ReasoningParser | None = None, + ) -> ErrorResponse | OmniChatCompletionResponse: created_time = int(time.time()) final_res: RequestOutput | None = None @@ -1284,8 +1464,7 @@ async def chat_completion_full_generator( except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) assert final_outputs is not None @@ -1296,12 +1475,23 @@ async def chat_completion_full_generator( prompt_logprobs = None prompt_token_ids = None kv_transfer_params = None + response_metrics: dict[str, Any] | None = None + + # Build requested modalities set for filtering + requested_modalities = ( + set(request.modalities) if hasattr(request, "modalities") and request.modalities else None + ) for omni_outputs in final_outputs: choices_data = [] if omni_outputs.request_output is not None and not getattr(omni_outputs.request_output, "finished", False): continue + # Filter outputs based on requested modalites + if requested_modalities is not None and omni_outputs.final_output_type not in requested_modalities: + logger.warning(f"final output type: {omni_outputs.final_output_type} is not needed by the request") + continue + if omni_outputs.final_output_type == "text": ( choices_data, @@ -1309,7 +1499,14 @@ async def chat_completion_full_generator( prompt_logprobs, prompt_token_ids, kv_transfer_params, - ) = self._create_text_choice(request, omni_outputs, tokenizer, conversation, role) + ) = self._create_text_choice( + request, + omni_outputs, + tokenizer, + conversation, + role, + reasoning_parser, + ) elif omni_outputs.final_output_type == "audio": choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False) elif omni_outputs.final_output_type == "image": @@ -1317,9 +1514,11 @@ async def chat_completion_full_generator( else: logger.warning(f"Unsupported final output type: {omni_outputs.final_output_type}") continue + if omni_outputs.metrics: + response_metrics = omni_outputs.metrics choices.extend(choices_data) - response = ChatCompletionResponse( + response = OmniChatCompletionResponse( id=request_id, created=created_time, model=model_name, @@ -1328,6 +1527,7 @@ async def chat_completion_full_generator( prompt_logprobs=prompt_logprobs, prompt_token_ids=prompt_token_ids, kv_transfer_params=kv_transfer_params, + metrics=response_metrics, ) # Log complete response if output logging is enabled @@ -1369,6 +1569,7 @@ def _create_text_choice( tokenizer: TokenizerLike, conversation: list[ConversationMessage], role: str, + reasoning_parser: ReasoningParser | None = None, ): final_res = omni_outputs.request_output if self.tool_call_id_type == "kimi_k2": @@ -1436,15 +1637,10 @@ def _create_text_choice( choices.append(choice_data) continue - if self.reasoning_parser: - try: - reasoning_parser = self.reasoning_parser(tokenizer) - except RuntimeError as e: - logger.exception("Error in reasoning parser creation.") - return self.create_error_response(str(e)) + if reasoning_parser: # If the reasoning parser is enabled, # tool calls are extracted exclusively from the content. - reasoning_content, content = reasoning_parser.extract_reasoning_content(output.text, request=request) + reasoning_content, content = reasoning_parser.extract_reasoning(output.text, request=request) if not request.include_reasoning: reasoning_content = None else: @@ -1526,7 +1722,7 @@ def _create_text_choice( tool_parser = self.tool_parser(tokenizer) except RuntimeError as e: logger.exception("Error in tool parser creation.") - return self.create_error_response(str(e)) + return self.create_error_response(e) tool_call_info = tool_parser.extract_tool_calls(content if content is not None else "", request=request) # In the OpenAI API the finish_reason is "tools_called" @@ -1611,23 +1807,42 @@ def _create_audio_choice( ): choices: list[ChatCompletionResponseChoice] = [] final_res = omni_outputs.request_output - audio_tensor = final_res.multimodal_output["audio"].float().detach().cpu().numpy() - - # Ensure audio is 1D (flatten if needed) - if audio_tensor.ndim > 1: - audio_tensor = audio_tensor.flatten() - - audio_obj = CreateAudio( - audio_tensor=audio_tensor, - sample_rate=24000, - response_format="wav", - speed=1.0, - stream_format="audio", - base64_encode=True, - ) + # HyperCLOVAXAudioPipeline (diffusion): audio is in omni_outputs.multimodal_output + # (final_res.request_output is None, so final_res.outputs == []). + # Qwen3-Omni pipeline: audio is in final_res.outputs[0].multimodal_output. + if final_res is not None and final_res.outputs: + audio_data = final_res.outputs[0].multimodal_output.get("audio") + else: + audio_data = omni_outputs.multimodal_output.get("audio") + # HyperCLOVAXAudioPipeline post-process returns bytes (WAV/PCM). + # Qwen3-Omni returns tensors or list-of-tensors. + if isinstance(audio_data, bytes): + audio_base64 = base64.b64encode(audio_data).decode("ascii") + else: + if isinstance(audio_data, list): + if stream: + audio_tensor = audio_data[-1] + else: + audio_tensor = torch.cat(audio_data, dim=-1) + else: + audio_tensor = audio_data + audio_tensor = audio_tensor.float().detach().cpu().numpy() + + # Ensure audio is 1D (flatten if needed) + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.flatten() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=24000, + response_format="wav", + speed=1.0, + stream_format="audio", + base64_encode=True, + ) - audio_response: AudioResponse = self.create_audio(audio_obj) - audio_base64 = audio_response.audio_data + audio_response: AudioResponse = self.create_audio(audio_obj) + audio_base64 = audio_response.audio_data # Generate unique ID for the audio audio_id = f"audio-{uuid.uuid4().hex[:16]}" @@ -1643,19 +1858,21 @@ def _create_audio_choice( transcript="", # Empty transcript if not available ) - for output in final_res.outputs: + _output_list = (final_res.outputs if (final_res is not None and final_res.outputs) + else [None]) + for _choice_idx, output in enumerate(_output_list): if stream: choice_data = ChatCompletionResponseStreamChoice( - index=output.index, + index=output.index if output is not None else _choice_idx, delta=DeltaMessage(role=role, content=audio_base64), logprobs=None, finish_reason="stop", - stop_reason=output.stop_reason, - token_ids=(as_list(output.token_ids) if request.return_token_ids else None), + stop_reason=output.stop_reason if output is not None else None, + token_ids=(as_list(output.token_ids) if (output is not None and request.return_token_ids) else None), ) else: choice_data = ChatCompletionResponseChoice( - index=output.index, + index=output.index if output is not None else _choice_idx, message=ChatMessage(role=role, audio=audio_obj), logprobs=None, finish_reason="stop", @@ -1691,9 +1908,11 @@ def _create_image_choice( if omni_outputs.images: images = omni_outputs.images # Fall back to request_output for pipeline mode - elif final_res is not None: - if hasattr(final_res, "multimodal_output") and final_res.multimodal_output: - image_data = final_res.multimodal_output.get("image") + # OMNI: Access multimodal_output from CompletionOutput (outputs[0]), not from RequestOutput + elif final_res is not None and final_res.outputs: + completion_output = final_res.outputs[0] + if hasattr(completion_output, "multimodal_output") and completion_output.multimodal_output: + image_data = completion_output.multimodal_output.get("image") if image_data is not None: if isinstance(image_data, Image.Image): images.append(image_data) @@ -1834,6 +2053,7 @@ async def _create_diffusion_chat_completion( # Text-to-video parameters (ref: text_to_video.py) num_frames = extra_body.get("num_frames") guidance_scale_2 = extra_body.get("guidance_scale_2") # For video high-noise CFG + lora_body = extra_body.get("lora") logger.info( "Diffusion chat request %s: prompt=%r, ref_images=%d, params=%s", @@ -1853,34 +2073,63 @@ async def _create_diffusion_chat_completion( logger.warning("Failed to decode reference image: %s", e) # Build generation kwargs - gen_kwargs: dict[str, Any] = { + gen_prompt: OmniTextPrompt = { "prompt": prompt, - "request_id": request_id, - "num_inference_steps": num_inference_steps, - "height": height, - "width": width, "negative_prompt": negative_prompt, - "num_outputs_per_prompt": num_outputs_per_prompt, - "seed": seed, } + gen_params = OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) if guidance_scale is not None: - gen_kwargs["guidance_scale"] = guidance_scale + gen_params.guidance_scale = guidance_scale # Add Qwen-Image specific parameter if true_cfg_scale is not None: - gen_kwargs["true_cfg_scale"] = true_cfg_scale + gen_params.true_cfg_scale = true_cfg_scale # Add video generation parameters if set if num_frames is not None: - gen_kwargs["num_frames"] = num_frames + gen_params.num_frames = num_frames if guidance_scale_2 is not None: - gen_kwargs["guidance_scale_2"] = guidance_scale_2 + gen_params.guidance_scale_2 = guidance_scale_2 + + # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). + if lora_body and isinstance(lora_body, dict): + try: + lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") + lora_path = ( + lora_body.get("local_path") + or lora_body.get("path") + or lora_body.get("lora_path") + or lora_body.get("lora_local_path") + ) + # using "or" directly here may be buggy if `scale=0` + lora_scale = lora_body.get("scale") + if lora_scale is None: + lora_scale = lora_body.get("lora_scale") + lora_int_id = lora_body.get("int_id") + if lora_int_id is None: + lora_int_id = lora_body.get("lora_int_id") + if lora_int_id is None and lora_path: + lora_int_id = stable_lora_int_id(str(lora_path)) + if lora_name and lora_path: + lora_req = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + gen_params.lora_request = lora_req + if lora_scale is not None: + gen_params.lora_scale = float(lora_scale) + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) # Add reference image if provided if pil_images: if len(pil_images) == 1: - gen_kwargs["pil_image"] = pil_images[0] + gen_prompt["multi_modal_data"] = {} + gen_prompt["multi_modal_data"]["image"] = pil_images[0] else: od_config = getattr(self._diffusion_engine, "od_config", None) supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) @@ -1888,7 +2137,8 @@ async def _create_diffusion_chat_completion( # TODO: entry is asyncOmni. We hack the od config here. supports_multimodal_inputs = True if supports_multimodal_inputs: - gen_kwargs["pil_image"] = pil_images + gen_prompt["multi_modal_data"] = {} + gen_prompt["multi_modal_data"]["image"] = pil_images else: return self._create_error_response( "Multiple input images are not supported by the current diffusion model. " @@ -1901,18 +2151,24 @@ async def _create_diffusion_chat_completion( # Handle both AsyncOmniDiffusion (returns OmniRequestOutput) and AsyncOmni (returns AsyncGenerator) if hasattr(self._diffusion_engine, "stage_list"): # AsyncOmni: iterate through async generator to get final output + diffusion_engine = cast(AsyncOmni, self._diffusion_engine) result = None - async for output in self._diffusion_engine.generate( - prompt=gen_kwargs["prompt"], - request_id=gen_kwargs.get("request_id"), - sampling_params_list=[gen_kwargs], # Pass as single-stage params + async for output in diffusion_engine.generate( + prompt=gen_prompt, + sampling_params_list=[gen_params], # Pass as single-stage params + request_id=request_id, ): result = output if result is None: return self._create_error_response("No output generated from AsyncOmni") else: # AsyncOmniDiffusion: direct call - result = await self._diffusion_engine.generate(**gen_kwargs) + diffusion_engine = cast(AsyncOmniDiffusion, self._diffusion_engine) + result = await diffusion_engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) # Extract images from result # Handle nested OmniRequestOutput structure where images might be in request_output images = getattr(result.request_output, "images", []) @@ -2049,7 +2305,9 @@ def _create_error_response( ) -> ErrorResponse: """Create an error response following OpenAI error format.""" return ErrorResponse( - message=message, - type=err_type, - code=status_code, + error=ErrorInfo( + message=message, + type=err_type, + code=status_code, + ) ) diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py index 71c5e8377ac..0679d4a988c 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py @@ -21,14 +21,26 @@ SupportsPP, ) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.qwen2_5_omni_thinker import ( - Qwen2_5OmniAudioFeatureInputs, - Qwen2_5OmniThinkerDummyInputsBuilder, - Qwen2_5OmniThinkerMultiModalProcessor, - Qwen2_5OmniThinkerProcessingInfo, - get_llm_pos_ids_for_vision, - split_list_into_ranges, -) +try: + from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniAudioFeatureInputs, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerProcessingInfo, + check_interleaved_audio_video, + merge_interleaved_embeddings, + ) +except ImportError: + from vllm.model_executor.models.qwen2_5_omni_thinker import ( # type: ignore[no-redef] + Qwen2_5OmniAudioFeatureInputs, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerProcessingInfo, + ) + + def check_interleaved_audio_video(*a, **k): + return False + + def merge_interleaved_embeddings(*a, **k): + raise NotImplementedError("merge_interleaved_embeddings not available") from vllm.model_executor.models.qwen2_5_omni_thinker import ( Qwen2_5OmniConditionalGenerationMixin as Qwen2_5OmniConditionalGenerationMixinBase, ) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py index 361a9349b25..ef58a315081 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py @@ -56,11 +56,27 @@ SupportsMultiModal, SupportsPP, ) -from vllm.model_executor.models.qwen2_5_omni_thinker import ( - Qwen2_5OmniAudioFeatureInputs, - Qwen2_5OmniThinkerDummyInputsBuilder, - Qwen2_5OmniThinkerMultiModalProcessor, -) +from vllm.model_executor.models.module_mapping import MultiModelKeys +try: + from vllm.model_executor.models.qwen2_5_omni_thinker import ( + Qwen2_5OmniAudioFeatureInputs, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + check_interleaved_audio_video, + merge_interleaved_embeddings, + ) +except ImportError: + from vllm.model_executor.models.qwen2_5_omni_thinker import ( # type: ignore[no-redef] + Qwen2_5OmniAudioFeatureInputs, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + ) + + def check_interleaved_audio_video(*a, **k): + return False + + def merge_interleaved_embeddings(*a, **k): + raise NotImplementedError("merge_interleaved_embeddings not available") from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VLProcessingInfo, ) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 2d2b7ef8e2d..0ac857d1564 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -427,7 +427,7 @@ def propose_draft_token_ids(sampled_token_ids): sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + pooler_output=(pooler_output if getattr(self.vllm_config.model_config, "engine_output_type", "text") != "text" else None), kv_connector_output=kv_connector_output, ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, num_nans_in_logits=num_nans_in_logits, From 6fac2d7a7c54b1047c8ad683d5e7d3e1af7ce3ba Mon Sep 17 00:00:00 2001 From: kje Date: Mon, 6 Apr 2026 09:18:33 +0900 Subject: [PATCH 3/4] fix: diffusion IPC, audio/vision decoder E2E pipeline fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - diffusion/ipc.py, diffusion_engine.py, diffusion_worker.py: IPC stability and worker lifecycle fixes for HCX audio+vision stages - diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py: finetuned audio decoder path, transformers_modules deserialization, zero-shot speaker embedding fallback - diffusion/registry.py, request.py: HCX Omni diffusion model registration and request type handling Validated E2E with HyperCLOVAX-SEED-Omni-8B: Speech-to-Speech → 11.84s / 568KB WAV (BigVGAN, 24kHz) Text-to-Vision → 768×768 PNG (diffusion, 50 steps) Co-Authored-By: 길재은 Co-Authored-By: Hyunjoon Jeong --- vllm_omni/diffusion/diffusion_engine.py | 43 +- vllm_omni/diffusion/ipc.py | 131 +++ .../models/hyperclovax_audio/__init__.py | 2 +- .../models/hyperclovax_audio/activations.py | 91 +-- .../models/hyperclovax_audio/constants.py | 16 +- .../models/hyperclovax_audio/ecapa_tdnn.py | 170 ++-- .../hyperclovax_audio_decoder.py | 452 ++++------- .../pipeline_hyperclovax_audio.py | 408 +++++----- vllm_omni/diffusion/registry.py | 183 ++++- vllm_omni/diffusion/request.py | 197 +---- .../diffusion/worker/diffusion_worker.py | 751 ++++++++++++++++++ 11 files changed, 1589 insertions(+), 855 deletions(-) create mode 100644 vllm_omni/diffusion/ipc.py create mode 100644 vllm_omni/diffusion/worker/diffusion_worker.py diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 0f21d4b2820..91e17b2d89b 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -415,21 +415,46 @@ def _dummy_run(self): prompt = "dummy run" # note that num_inference_steps=1 will cause timestep and temb None in the pipeline num_inference_steps = 1 - height = 1024 - width = 1024 + height = 256 + width = 256 if supports_image_input(self.od_config.model_class_name): # Provide a dummy image input if the model supports it dummy_image = PIL.Image.new("RGB", (width, height), color=(0, 0, 0)) else: - dummy_image = None + dummy_audio = None + + # Collect dummy extra tokens from the pipeline class if available. + # Some pipelines (e.g. HyperCLOVAXVisionPipeline) require tokens in + # req.extra that are normally populated by stage input processors. + model_cls = DiffusionModelRegistry._try_load_model_cls( + self.od_config.model_class_name + ) + dummy_extra = {} + if model_cls is not None and hasattr(model_cls, "get_dummy_extra"): + dummy_extra = model_cls.get_dummy_extra() + + prompt: OmniTextPrompt = { + "prompt": "dummy run", + "multi_modal_data": {"image": dummy_image, "audio": dummy_audio}, + } req = OmniDiffusionRequest( - prompt=prompt, - height=height, - width=width, - pil_image=dummy_image, - num_inference_steps=num_inference_steps, - num_outputs_per_prompt=1, + prompts=[prompt], + request_ids=["dummy_req_id"], + extra=dummy_extra, + sampling_params=OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=num_inference_steps, + # Keep warmup path minimal and robust across text encoders. + # Some models may fail when warmup implicitly triggers + # classifier-free guidance with an empty negative prompt. + guidance_scale=0.0, + num_outputs_per_prompt=1, + # Disable CFG for warmup to avoid triggering CFG parallel + # validation when cfg_parallel_size > 1. + extra_args={"cfg_text_scale": 1.0, "cfg_img_scale": 1.0}, + ), ) logger.info("dummy run to warm up the model") requests = self.pre_process_func([req]) if self.pre_process_func is not None else [req] diff --git a/vllm_omni/diffusion/ipc.py b/vllm_omni/diffusion/ipc.py new file mode 100644 index 00000000000..93203cf2eaf --- /dev/null +++ b/vllm_omni/diffusion/ipc.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""IPC utilities for transferring large tensors via POSIX shared memory. + +Used by Hop1 (GPU worker <-> scheduler) to avoid pickling large video tensors +through the MessageQueue. Tensors above ``_SHM_TENSOR_THRESHOLD`` are copied +into a named shared-memory segment; only a lightweight metadata dict is +serialised through the queue. +""" + +from __future__ import annotations + +from typing import Any + +import torch + +from vllm_omni.diffusion.data import DiffusionOutput + +_SHM_TENSOR_THRESHOLD = 0 # Always use SHM for CUDA tensor safety + + +def _tensor_to_shm(tensor: torch.Tensor) -> dict[str, Any]: + """Copy a tensor into POSIX shared memory and return a metadata handle. + + The shared memory segment remains alive after this call (the local fd is + closed, but the segment persists until ``_tensor_from_shm`` unlinks it). + + BFloat16 and other numpy-incompatible dtypes are stored as raw uint8 bytes + and reconstructed using the stored ``torch_dtype``. + """ + from multiprocessing import shared_memory + + import numpy as np + + orig_dtype = tensor.dtype + tensor = tensor.detach().cpu().contiguous() + # BFloat16 (and some other dtypes) are not natively supported by numpy. + # Use a raw uint8 byte view so data can be round-tripped without precision loss. + try: + arr = tensor.numpy() + use_raw_bytes = False + except TypeError: + arr = tensor.view(torch.uint8).numpy() + use_raw_bytes = True + nbytes = arr.nbytes + shm = shared_memory.SharedMemory(create=True, size=nbytes) + shm_arr = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf[:nbytes]) + np.copyto(shm_arr, arr) + handle = { + "__tensor_shm__": True, + "name": shm.name, + "shape": list(tensor.shape), + "torch_dtype": str(orig_dtype), + "numpy_dtype": str(arr.dtype), + "nbytes": nbytes, + "raw_bytes": use_raw_bytes, + } + shm.close() + return handle + + +def _tensor_from_shm(handle: dict[str, Any]) -> torch.Tensor: + """Reconstruct a tensor from a shared-memory handle and free the segment.""" + from multiprocessing import shared_memory + + import numpy as np + + shm = shared_memory.SharedMemory(name=handle["name"]) + try: + np_dtype = np.dtype(handle["numpy_dtype"]) + if handle.get("raw_bytes"): + # Data was stored as raw uint8 bytes (e.g. BFloat16 round-trip). + byte_arr = np.ndarray(handle["nbytes"], dtype=np.uint8, buffer=shm.buf[: handle["nbytes"]]) + raw = torch.from_numpy(byte_arr.copy()) + else: + arr = np.ndarray(handle["shape"], dtype=np_dtype, buffer=shm.buf[: handle["nbytes"]]) + raw = torch.from_numpy(arr.copy()) + finally: + shm.close() + shm.unlink() + # Restore the original torch dtype (handles BF16 raw-byte round-trip). + torch_dtype_str = handle["torch_dtype"].replace("torch.", "") + torch_dtype = getattr(torch, torch_dtype_str) + if raw.dtype != torch_dtype or handle.get("raw_bytes"): + raw = raw.view(torch_dtype).reshape(handle["shape"]) + return raw + + +def _pack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput: + if output.output is not None and isinstance(output.output, torch.Tensor): + if output.output.nelement() * output.output.element_size() > _SHM_TENSOR_THRESHOLD: + output.output = _tensor_to_shm(output.output) + if output.trajectory_latents is not None and isinstance(output.trajectory_latents, torch.Tensor): + if output.trajectory_latents.nelement() * output.trajectory_latents.element_size() > _SHM_TENSOR_THRESHOLD: + output.trajectory_latents = _tensor_to_shm(output.trajectory_latents) + return output + + +def pack_diffusion_output_shm(output: object) -> object: + """Replace large tensors in diffusion worker outputs with SHM handles. + + Supports either a bare ``DiffusionOutput`` or a wrapper object carrying one + in ``.result`` (for example ``RunnerOutput``). + """ + if isinstance(output, DiffusionOutput): + return _pack_diffusion_fields(output) + + result = getattr(output, "result", None) + if isinstance(result, DiffusionOutput): + output.result = _pack_diffusion_fields(result) + return output + + +def _unpack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput: + if isinstance(output.output, dict) and output.output.get("__tensor_shm__"): + output.output = _tensor_from_shm(output.output) + if isinstance(output.trajectory_latents, dict) and output.trajectory_latents.get("__tensor_shm__"): + output.trajectory_latents = _tensor_from_shm(output.trajectory_latents) + return output + + +def unpack_diffusion_output_shm(output: object) -> object: + """Reconstruct tensors from SHM handles in diffusion worker outputs.""" + if isinstance(output, DiffusionOutput): + return _unpack_diffusion_fields(output) + + result = getattr(output, "result", None) + if isinstance(result, DiffusionOutput): + output.result = _unpack_diffusion_fields(result) + return output diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py b/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py index 14cb19ddc31..cbbeca82ed3 100644 --- a/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py +++ b/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py @@ -10,7 +10,7 @@ ) __all__ = [ - "HyperCLOVAXAudioPipeline", "HyperCLOVAXAudioDecoderModel", + "HyperCLOVAXAudioPipeline", "get_hyperclovax_audio_post_process_func", ] diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/activations.py b/vllm_omni/diffusion/models/hyperclovax_audio/activations.py index 0e9aa4a2a35..1e1a6c0521e 100644 --- a/vllm_omni/diffusion/models/hyperclovax_audio/activations.py +++ b/vllm_omni/diffusion/models/hyperclovax_audio/activations.py @@ -11,14 +11,7 @@ if "sinc" in dir(torch): sinc = torch.sinc else: - # This code is adopted from adefossez's julius.core.sinc under the MIT License - # https://adefossez.github.io/julius/julius/core.html - # See NOTICE file for license details. def sinc(x: torch.Tensor): - """ - Implementation of sinc, i.e. sin(pi * x) / (pi * x) - __Warning__: Different to julius.sinc, the input is multiplied by `pi`! - """ return torch.where( x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), @@ -26,14 +19,17 @@ def sinc(x: torch.Tensor): ) -# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License -# https://adefossez.github.io/julius/julius/lowpass.html -# See NOTICE file for license details. -def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + """Return filter [1, 1, kernel_size]. + + BUG FIX: Original PR #869 had two bugs here: + 1. Variable name typo: assigned to 'filter' but returned 'filter' (unbound in cutoff==0 path) + 2. cutoff==0 path didn't return properly + Both fixed by using 'filter_' consistently. + """ even = kernel_size % 2 == 0 half_size = kernel_size // 2 - # For kaiser window delta_f = 4 * half_width A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 if A > 50.0: @@ -44,23 +40,18 @@ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1 beta = 0.0 window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) - # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio if even: time = torch.arange(-half_size, half_size) + 0.5 else: time = torch.arange(kernel_size) - half_size + if cutoff == 0: filter_ = torch.zeros_like(time) else: filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) - """ - Normalize filter to have sum = 1, - otherwise we will have a small leakage of the constant component in the input signal. - """ filter_ /= filter_.sum() - filter = filter_.view(1, 1, kernel_size) - return filter + return filter_.view(1, 1, kernel_size) class LowPassFilter1d(nn.Module): @@ -74,9 +65,6 @@ def __init__( kernel_size: int = 12, causal: bool = False, ): - """ - kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible. - """ super().__init__() if cutoff < -0.0: raise ValueError("Minimum cutoff must be larger than zero.") @@ -97,16 +85,15 @@ def __init__( filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) self.register_buffer("filter", filter) - # Input [B, C, T] def forward(self, x, hidden_states=None): _, C, _ = x.shape - hs = x[..., -self.pad_left :] + hs = x[..., -self.pad_left:] if self.padding: if self.causal: if hidden_states is not None: assert hidden_states.shape[-1] >= self.pad_left - hidden_states = hidden_states[..., -self.pad_left :] + hidden_states = hidden_states[..., -self.pad_left:] x = torch.cat([hidden_states, x], dim=-1) else: x = F.pad(x, (self.pad_left, 0), mode="constant", value=0.0) @@ -126,23 +113,22 @@ def __init__(self, ratio=2, kernel_size=None, causal=False): self.pad = self.kernel_size // ratio - 1 self.causal = causal - self.half_left = (kernel_size - ratio) // 2 - self.half_right = (kernel_size - ratio + 1) // 2 + self.half_left = (self.kernel_size - ratio) // 2 + self.half_right = (self.kernel_size - ratio + 1) // 2 filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) self.register_buffer("filter", filter) - # x: [B, C, T] def forward(self, x, hidden_states=None): _, C, _ = x.shape - hs = x[..., -self.pad :] + hs = x[..., -self.pad:] pad_left = self.pad pad_right = 0 if self.causal else self.pad if hidden_states is not None: assert hidden_states.shape[-1] >= self.pad - hidden_states = hidden_states[..., -self.pad :] + hidden_states = hidden_states[..., -self.pad:] x = torch.cat([hidden_states, x], dim=-1) if pad_right > 0: x = F.pad(x, (0, pad_right), mode="replicate") @@ -176,70 +162,34 @@ def __init__(self, ratio=2, kernel_size=None, causal=False): def forward(self, x, hidden_states=None): xx, hs = self.lowpass(x, hidden_states) - return xx, hs.detach() class SnakeBeta(nn.Module): - """ - A modified Snake function which uses separate parameters for the magnitude of the periodic components - Shape: - - Input: (B, C, T) - - Output: (B, C, T), same shape as the input - Parameters: - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - References: - - This activation function is a modified version based on this paper - by Liu Ziyin, Tilman Hartwig, Masahito Ueda: - https://arxiv.org/abs/2006.08195 - Examples: - >>> a1 = snakebeta(256) - >>> x = torch.randn(256) - >>> x = a1(x) - """ + """SnakeBeta: x + (1/beta) * sin^2(x * alpha)""" def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): - """ - Initialization. - INPUT: - - in_features: shape of the input - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - alpha is initialized to 1 by default, higher values = higher-frequency. - beta is initialized to 1 by default, higher values = higher-magnitude. - alpha will be trained along with the rest of your model. - """ super().__init__() self.in_features = in_features - - # Initialize alpha self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # Log scale alphas initialized to zeros + if self.alpha_logscale: self.alpha = Parameter(torch.zeros(in_features) * alpha) self.beta = Parameter(torch.zeros(in_features) * alpha) - else: # Linear scale alphas initialized to ones + else: self.alpha = Parameter(torch.ones(in_features) * alpha) self.beta = Parameter(torch.ones(in_features) * alpha) self.alpha.requires_grad = alpha_trainable self.beta.requires_grad = alpha_trainable - self.no_div_by_zero = 0.000000001 def forward(self, x): - """ - Forward pass of the function. - Applies the function to the input elementwise. - SnakeBeta ∶= x + 1/b * sin^2 (xa) - """ - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) beta = self.beta.unsqueeze(0).unsqueeze(-1) if self.alpha_logscale: alpha = torch.exp(alpha) beta = torch.exp(beta) x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) - return x @@ -260,7 +210,6 @@ def __init__( self.upsample = UpSample1d(up_ratio, up_kernel_size, causal) self.downsample = DownSample1d(down_ratio, down_kernel_size, causal) - # x: [B,C,T] def forward(self, x, hidden_states=None): if hidden_states is None: hidden_states = [None] * 2 diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/constants.py b/vllm_omni/diffusion/models/hyperclovax_audio/constants.py index 6b993d06b75..b8888eb23e0 100644 --- a/vllm_omni/diffusion/models/hyperclovax_audio/constants.py +++ b/vllm_omni/diffusion/models/hyperclovax_audio/constants.py @@ -15,14 +15,14 @@ DEFAULT_FORMAT = "wav" AUDIO_FORMAT_MAP = [ - (b"RIFF", "wav"), # WAV (RIFF container) - (b"\x1a\x45\xdf\xa3", "webm"), # WebM / MKV (EBML header) - (b"OggS", "ogg"), # OGG - (b"fLaC", "flac"), # FLAC - (b"ID3", "mp3"), # MP3 with ID3 tag - (b"\xff\xfb", "mp3"), # MP3 without ID3 - (b"\x00\x00\x00\x1c", "mp4"), # MP4 / M4A - (b"\x00\x00\x00\x20", "mp4"), # MP4 / M4A + (b"RIFF", "wav"), + (b"\x1a\x45\xdf\xa3", "webm"), + (b"OggS", "ogg"), + (b"fLaC", "flac"), + (b"ID3", "mp3"), + (b"\xff\xfb", "mp3"), + (b"\x00\x00\x00\x1c", "mp4"), + (b"\x00\x00\x00\x20", "mp4"), ] VOLUME_LEVEL_DB = -26 diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py b/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py index d9eb67c0ba6..72619088466 100644 --- a/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py +++ b/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py @@ -2,11 +2,11 @@ # Copyright (c) 2025-present NAVER Cloud Corp. # Apache-2.0 # -# This is the ECAPA-TDNN model. -# "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification" +# ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in +# TDNN Based Speaker Verification. # https://arxiv.org/pdf/2005.07143 # -# This model is modified based on the following projects: +# Modified from: # - https://github.com/lawlict/ECAPA-TDNN/blob/master/ecapa_tdnn.py # - https://github.com/TaoRuijie/ECAPA-TDNN/blob/main/model.py (MIT License) @@ -16,44 +16,21 @@ class Res2Conv1dReluBn(nn.Module): - """ - Res2Conv1d + BatchNorm1d + ReLU - NOTE: in_channels == out_channels == channels - """ - - def __init__( - self, - channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - bias=False, - scale=4, - ): + def __init__(self, channels, kernel_size=1, stride=1, padding=0, + dilation=1, bias=False, scale=4): super().__init__() assert channels % scale == 0, f"{channels} % {scale} != 0" self.scale = scale self.width = channels // scale self.nums = scale if scale == 1 else scale - 1 - self.convs = [] - self.bns = [] - for i in range(self.nums): + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + for _ in range(self.nums): self.convs.append( - nn.Conv1d( - self.width, - self.width, - kernel_size, - stride, - padding, - dilation, - bias=bias, - ) - ) + nn.Conv1d(self.width, self.width, kernel_size, stride, + padding, dilation, bias=bias)) self.bns.append(nn.BatchNorm1d(self.width)) - self.convs = nn.ModuleList(self.convs) - self.bns = nn.ModuleList(self.bns) def forward(self, x): out = [] @@ -63,7 +40,6 @@ def forward(self, x): split = x_splits[i] else: split = split + x_splits[i] - # Order: conv -> relu -> bn split = self.convs[i](split) split = self.bns[i](F.relu(split)) out.append(split) @@ -74,20 +50,11 @@ def forward(self, x): class Conv1dReluBn(nn.Module): - """Conv1d + BatchNorm1d + ReLU""" - - def __init__( - self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - bias=False, - ): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=0, dilation=1, bias=False): super().__init__() - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, + padding, dilation, bias=bias) self.bn = nn.BatchNorm1d(out_channels) def forward(self, x): @@ -95,8 +62,6 @@ def forward(self, x): class SE_Connect(nn.Module): - """The SE connection of 1D case.""" - def __init__(self, channels, bottleneck_dim): super().__init__() self.linear1 = nn.Linear(channels, bottleneck_dim) @@ -111,128 +76,86 @@ def forward(self, x): class SE_Res2Block(nn.Module): - """SE-Res2Block of the ECAPA-TDNN architecture.""" - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - scale, - se_bottleneck_dim, - ): + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, scale, se_bottleneck_dim): super().__init__() - self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale) - self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, + kernel_size=1, stride=1, padding=0) + self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, + stride, padding, dilation, + scale=scale) + self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, + kernel_size=1, stride=1, padding=0) self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) self.shortcut = None if in_channels != out_channels: - self.shortcut = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - ) + self.shortcut = nn.Conv1d(in_channels=in_channels, + out_channels=out_channels, kernel_size=1) def forward(self, x): residual = x if self.shortcut: residual = self.shortcut(x) - x = self.Conv1dReluBn1(x) x = self.Res2Conv1dReluBn(x) x = self.Conv1dReluBn2(x) x = self.SE_Connect(x) - return x + residual class AttentiveStatsPool(nn.Module): - def __init__(self, in_dim, attention_channels=128, global_context_att=False): + def __init__(self, in_dim, attention_channels=128, + global_context_att=False): super().__init__() self.global_context_att = global_context_att - if global_context_att: - self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper + self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, + kernel_size=1) else: - self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper - self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper + self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) + self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) def forward(self, x): if self.global_context_att: context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) - context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x) x_in = torch.cat((x, context_mean, context_std), dim=1) else: x_in = x - - # DON'T use ReLU here! In experiments, I find ReLU hard to converge. alpha = torch.tanh(self.linear1(x_in)) alpha = torch.softmax(self.linear2(alpha), dim=2) mean = torch.sum(alpha * x, dim=2) - residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 + residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2 std = torch.sqrt(residuals.clamp(min=1e-9)) return torch.cat([mean, std], dim=1) -""" Implementation of - "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification". - - Note that we DON'T concatenate the last frame-wise layer with non-weighted mean and standard deviation, - because it brings little improvement but significantly increases model parameters. - As a result, this implementation basically equals the A.2 of Table 2 in the paper. -""" - - class ECAPA_TDNN(nn.Module): - def __init__(self, in_channel=100, hidden_channel=512, emb_dim=256, global_context_att=False): + def __init__(self, in_channel=100, hidden_channel=512, emb_dim=256, + global_context_att=False): super().__init__() self.instance_norm = nn.InstanceNorm1d(in_channel) self.channels = [hidden_channel] * 4 + [hidden_channel * 3] - self.layer1 = Conv1dReluBn(in_channel, self.channels[0], kernel_size=5, padding=2) + self.layer1 = Conv1dReluBn(in_channel, self.channels[0], + kernel_size=5, padding=2) self.layer2 = SE_Res2Block( - self.channels[0], - self.channels[1], - kernel_size=3, - stride=1, - padding=2, - dilation=2, - scale=8, - se_bottleneck_dim=128, - ) + self.channels[0], self.channels[1], kernel_size=3, stride=1, + padding=2, dilation=2, scale=8, se_bottleneck_dim=128) self.layer3 = SE_Res2Block( - self.channels[1], - self.channels[2], - kernel_size=3, - stride=1, - padding=3, - dilation=3, - scale=8, - se_bottleneck_dim=128, - ) + self.channels[1], self.channels[2], kernel_size=3, stride=1, + padding=3, dilation=3, scale=8, se_bottleneck_dim=128) self.layer4 = SE_Res2Block( - self.channels[2], - self.channels[3], - kernel_size=3, - stride=1, - padding=4, - dilation=4, - scale=8, - se_bottleneck_dim=128, - ) + self.channels[2], self.channels[3], kernel_size=3, stride=1, + padding=4, dilation=4, scale=8, se_bottleneck_dim=128) cat_channels = hidden_channel * 3 self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) self.pooling = AttentiveStatsPool( - self.channels[-1], - attention_channels=128, - global_context_att=global_context_att, - ) + self.channels[-1], attention_channels=128, + global_context_att=global_context_att) self.bn = nn.BatchNorm1d(self.channels[-1] * 2) self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) self.bn_out = nn.BatchNorm1d(emb_dim) @@ -242,7 +165,6 @@ def forward(self, x): out2 = self.layer2(out1) out3 = self.layer3(out2) out4 = self.layer4(out3) - out = torch.cat([out2, out3, out4], dim=1) out = F.relu(self.conv(out)) out = self.bn(self.pooling(out)) diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/hyperclovax_audio_decoder.py b/vllm_omni/diffusion/models/hyperclovax_audio/hyperclovax_audio_decoder.py index 58af36ffd45..0cbab9c8839 100644 --- a/vllm_omni/diffusion/models/hyperclovax_audio/hyperclovax_audio_decoder.py +++ b/vllm_omni/diffusion/models/hyperclovax_audio/hyperclovax_audio_decoder.py @@ -5,6 +5,7 @@ # Portions from https://github.com/NVIDIA/BigVGAN under the MIT license. # See NOTICE file for license details. +import inspect import json import math from pathlib import Path @@ -13,13 +14,18 @@ import torch.nn as nn from torch.nn.utils import remove_weight_norm from torch.nn.utils.parametrizations import weight_norm +from vllm.logger import init_logger from vllm_omni.diffusion.data import OmniDiffusionConfig -from vllm_omni.diffusion.models.hyperclovax_audio.activations import Activation1d, SnakeBeta +from vllm_omni.diffusion.models.hyperclovax_audio.activations import ( + Activation1d, + SnakeBeta, +) from vllm_omni.diffusion.models.hyperclovax_audio.ecapa_tdnn import ECAPA_TDNN +logger = init_logger(__name__) + -# Dataclass for model hyper-parameters class AttrDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -32,7 +38,6 @@ def load_hparams_from_json(path) -> AttrDict: return AttrDict(json.loads(data)) -# Functions for model initialization def init_weights(m, mean=0.0, std=0.01): if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): m.weight.data.normal_(mean, std) @@ -43,84 +48,45 @@ def get_padding(kernel_size, dilation=1): class CausalConv1d(nn.Module): - """1D causal convloution w/ 1-side padding.""" - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - pad_buffer=None, - ): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + pad_buffer=None): super().__init__() self.conv = weight_norm( - nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=0, - dilation=dilation, - groups=groups, - bias=bias, - ) - ) + nn.Conv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=0, + dilation=dilation, groups=groups, bias=bias)) self.stride = stride self.pad_length = (kernel_size - 1) * dilation - - # TODO: deprecate pad_buffer and inference. Remove in the future if pad_buffer is None: pad_buffer = torch.zeros(1, in_channels, self.pad_length) self.register_buffer("pad_buffer", pad_buffer) def forward(self, x, hidden_states=None): if hidden_states is None: - x = nn.functional.pad(x, (self.pad_length, 0), "constant", value=0.0) + x = nn.functional.pad(x, (self.pad_length, 0), "constant", + value=0.0) else: assert hidden_states.shape[-1] >= self.pad_length - hidden_states = hidden_states[:, :, -self.pad_length :] + hidden_states = hidden_states[:, :, -self.pad_length:] x = torch.cat((hidden_states, x), -1) - return self.conv(x), x[:, :, -self.pad_length :].detach() + return self.conv(x), x[:, :, -self.pad_length:].detach() class CausalConvTranspose1d(nn.Module): - """1D causal transpose convloution.""" - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding=0, - output_padding=0, - groups=1, - bias=True, - pad_buffer=None, - ): + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding=0, output_padding=0, groups=1, bias=True, + pad_buffer=None): super().__init__() self.deconv = weight_norm( - nn.ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=0, - output_padding=0, - groups=groups, - bias=bias, - ) - ) + nn.ConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, stride=stride, + padding=0, output_padding=0, groups=groups, + bias=bias)) self.stride = stride self.pad_length = math.ceil(kernel_size / stride) - 1 self.pad = nn.ReplicationPad1d((self.pad_length, 0)) - - # TODO: deprecate pad_buffer and inference. Remove in the future if pad_buffer is None: pad_buffer = torch.zeros(1, in_channels, self.pad_length) self.register_buffer("pad_buffer", pad_buffer) @@ -130,41 +96,22 @@ def forward(self, x, hidden_states=None): x = self.pad(x) else: assert hidden_states.shape[-1] >= self.pad_length - hidden_states = hidden_states[:, :, -self.pad_length :] + hidden_states = hidden_states[:, :, -self.pad_length:] x = torch.cat((hidden_states, x), -1) return ( - self.deconv(x)[:, :, self.stride : -self.stride], - x[:, :, -self.pad_length :].detach(), + self.deconv(x)[:, :, self.stride:-self.stride], + x[:, :, -self.pad_length:].detach(), ) class NonCausalConv1d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - **kwargs, - ): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, **kwargs): super().__init__() self.conv = weight_norm( - nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - **kwargs, - ) - ) + nn.Conv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias, **kwargs)) self.pad_length = ((kernel_size - 1) * dilation) // 2 def forward(self, x, hidden_states=None): @@ -172,39 +119,22 @@ def forward(self, x, hidden_states=None): out = self.conv(x) else: assert hidden_states.shape[-1] >= self.pad_length - hidden_states = hidden_states[:, :, -self.pad_length :] + hidden_states = hidden_states[:, :, -self.pad_length:] x_ = torch.cat((hidden_states, x), -1) - out = self.conv(x_)[:, :, self.pad_length :] - return out, x[:, :, -self.pad_length :].detach() + out = self.conv(x_)[:, :, self.pad_length:] + return out, x[:, :, -self.pad_length:].detach() class NonCausalConvTranspose1d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - bias=True, - **kwargs, - ): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, **kwargs): super().__init__() self.deconv = weight_norm( - nn.ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - bias=bias, - **kwargs, - ) - ) + nn.ConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, output_padding=output_padding, + groups=groups, bias=bias, **kwargs)) self.stride = stride self.pad_length = (kernel_size - stride) // 2 @@ -213,90 +143,50 @@ def forward(self, x, hidden_states=None): out = self.deconv(x) else: assert hidden_states.shape[-1] >= self.pad_length - hidden_states = hidden_states[:, :, -self.pad_length :] + hidden_states = hidden_states[:, :, -self.pad_length:] x_ = torch.cat((hidden_states, x), -1) - out = self.deconv(x_)[:, :, self.pad_length * self.stride :] - return out, x[:, :, -self.pad_length :].detach() + out = self.deconv(x_)[:, :, self.pad_length * self.stride:] + return out, x[:, :, -self.pad_length:].detach() class AMPBlock1(torch.nn.Module): - """ - AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters - that control periodicity, defined for each layer. - AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 - followed by each layer in self.convs1 - - Args: - h (AttrDict): Hyperparameters. - channels (int): Number of convolution channels. - kernel_size (int): Size of the convolution kernel. Default is 3. - dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. - Default is (1, 3, 5). - activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. - Default is None. - """ + """Anti-aliased Multi-Periodicity residual block with SnakeBeta activations.""" - def __init__( - self, - h: AttrDict, - channels: int, - kernel_size: int = 3, - dilation: tuple = (1, 3, 5), - activation: str = None, - causal: bool = False, - ): + def __init__(self, h: AttrDict, channels: int, kernel_size: int = 3, + dilation: tuple = (1, 3, 5), activation: str = None, + causal: bool = False): super().__init__() conv1d = CausalConv1d if causal else NonCausalConv1d - self.h = h - self.convs1 = nn.ModuleList( - [ - conv1d( - channels, - channels, - kernel_size, - stride=1, - dilation=d, - padding=get_padding(kernel_size, d), - ) - for d in dilation - ] - ) + self.convs1 = nn.ModuleList([ + conv1d(channels, channels, kernel_size, stride=1, dilation=d, + padding=get_padding(kernel_size, d)) + for d in dilation + ]) self.convs1.apply(init_weights) - self.convs2 = nn.ModuleList( - [ - conv1d( - channels, - channels, - kernel_size, - stride=1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - for _ in range(len(dilation)) - ] - ) + self.convs2 = nn.ModuleList([ + conv1d(channels, channels, kernel_size, stride=1, dilation=1, + padding=get_padding(kernel_size, 1)) + for _ in range(len(dilation)) + ]) self.convs2.apply(init_weights) - self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers + self.num_layers = len(self.convs1) + len(self.convs2) - # Activation functions if activation == "snakebeta": - self.activations = nn.ModuleList( - [ - Activation1d( - activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale), - causal=False, - ) - for _ in range(self.num_layers) - ] - ) + self.activations = nn.ModuleList([ + Activation1d( + activation=SnakeBeta(channels, + alpha_logscale=h.snake_logscale), + causal=False) + for _ in range(self.num_layers) + ]) else: raise NotImplementedError( - "activation incorrectly specified. check the config file and look for 'activation'." - ) + "activation incorrectly specified. check the config file " + "and look for 'activation'.") def forward(self, x, hidden_states=None): if hidden_states is None: @@ -304,7 +194,8 @@ def forward(self, x, hidden_states=None): hidden_states_new = [] acts1, acts2 = self.activations[::2], self.activations[1::2] - for c1, c2, a1, a2, (h_a1, h_c1, h_a2, h_c2) in zip(self.convs1, self.convs2, acts1, acts2, hidden_states): + for c1, c2, a1, a2, (h_a1, h_c1, h_a2, h_c2) in zip( + self.convs1, self.convs2, acts1, acts2, hidden_states): xt, ht_a1 = a1(x, h_a1) xt, ht_c1 = c1(xt, h_c1) xt, ht_a2 = a2(xt, h_a2) @@ -322,15 +213,10 @@ def remove_weight_norm(self): class HyperCLOVAXAudioDecoderModel(nn.Module): - """ - HyperCLOVAXAudioDecoderModel is a neural vocoder model that applies anti-aliased periodic activation - for residual blocks (resblocks). + """Unit-BigVGAN vocoder: discrete audio codes → 24kHz waveform. Args: - od_config (OmniDiffusionConfig): Configuration object containing model hyperparameters. - - Note: - Ensure that the activation function is correctly specified in the hyperparameters (h.activation). + od_config: OmniDiffusionConfig containing model hyperparameters. """ def __init__( @@ -343,7 +229,8 @@ def __init__( upsample_kernel_sizes: list[int] = [10, 8, 8, 6, 4, 4], upsample_initial_channel: int = 1536, resblock_kernel_sizes: list[int] = [3, 7, 11], - resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], + [1, 3, 5]], use_tanh_at_final: bool = False, use_bias_at_final: bool = False, activation: str = "snakebeta", @@ -364,45 +251,43 @@ def __init__( ): super().__init__() - self.h = AttrDict( - { - "resblock": resblock, - "causal": causal, - "finetune": finetune, - "upsample_rates": upsample_rates, - "upsample_kernel_sizes": upsample_kernel_sizes, - "upsample_initial_channel": upsample_initial_channel, - "resblock_kernel_sizes": resblock_kernel_sizes, - "resblock_dilation_sizes": resblock_dilation_sizes, - "use_tanh_at_final": use_tanh_at_final, - "use_bias_at_final": use_bias_at_final, - "activation": activation, - "snake_logscale": snake_logscale, - "num_units": num_units, - "unit_emb_dim": unit_emb_dim, - "num_mels": num_mels, - "n_fft": n_fft, - "hop_size": hop_size, - "win_size": win_size, - "spk_emb_dim": spk_emb_dim, - "spk_hidden_dim": spk_hidden_dim, - "global_context_att": global_context_att, - "sampling_rate": sampling_rate, - "fmin": fmin, - "fmax": fmax, - "num_spk": num_spk, - } - ) + self.h = AttrDict({ + "resblock": resblock, + "causal": causal, + "finetune": finetune, + "upsample_rates": upsample_rates, + "upsample_kernel_sizes": upsample_kernel_sizes, + "upsample_initial_channel": upsample_initial_channel, + "resblock_kernel_sizes": resblock_kernel_sizes, + "resblock_dilation_sizes": resblock_dilation_sizes, + "use_tanh_at_final": use_tanh_at_final, + "use_bias_at_final": use_bias_at_final, + "activation": activation, + "snake_logscale": snake_logscale, + "num_units": num_units, + "unit_emb_dim": unit_emb_dim, + "num_mels": num_mels, + "n_fft": n_fft, + "hop_size": hop_size, + "win_size": win_size, + "spk_emb_dim": spk_emb_dim, + "spk_hidden_dim": spk_hidden_dim, + "global_context_att": global_context_att, + "sampling_rate": sampling_rate, + "fmin": fmin, + "fmax": fmax, + "num_spk": num_spk, + }) self.causal = self.h.get("causal", True) conv1d = CausalConv1d if self.causal else NonCausalConv1d - convtranspose1d = CausalConvTranspose1d if self.causal else NonCausalConvTranspose1d + convtranspose1d = (CausalConvTranspose1d if self.causal + else NonCausalConvTranspose1d) self.num_kernels = len(self.h.resblock_kernel_sizes) self.num_upsamples = len(self.h.upsample_rates) self.finetune = getattr(self.h, "finetune", False) - # Speaker embedding if not self.finetune: self.spk_emb = ECAPA_TDNN( in_channel=self.h.num_mels, @@ -413,72 +298,63 @@ def __init__( else: self.spk_emb = nn.Embedding(self.h.num_spk, self.h.spk_emb_dim) - # Unit embedding self.unit_emb = nn.Embedding(self.h.num_units, self.h.unit_emb_dim) self.unit_emb_dim = self.h.unit_emb_dim - # Pre-conv self.conv_pre = conv1d( - self.h.unit_emb_dim + self.h.spk_emb_dim, self.h.upsample_initial_channel, 7, 1, padding=3 - ) + self.h.unit_emb_dim + self.h.spk_emb_dim, + self.h.upsample_initial_channel, 7, 1, padding=3) - # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default if self.h.resblock == "1": resblock_class = AMPBlock1 else: - raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {self.h.resblock}") + raise ValueError( + f"Incorrect resblock class specified. Got {self.h.resblock}") - # Transposed conv-based upsamplers. does not apply anti-aliasing self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(self.h.upsample_rates, self.h.upsample_kernel_sizes)): - self.ups.append( - nn.ModuleList( - [ - convtranspose1d( - self.h.upsample_initial_channel // (2**i), - self.h.upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=math.ceil((k - u) / 2), - output_padding=(k - u) % 2, - ) - ] + for i, (u, k) in enumerate( + zip(self.h.upsample_rates, self.h.upsample_kernel_sizes)): + self.ups.append(nn.ModuleList([ + convtranspose1d( + self.h.upsample_initial_channel // (2 ** i), + self.h.upsample_initial_channel // (2 ** (i + 1)), + k, u, + padding=math.ceil((k - u) / 2), + output_padding=(k - u) % 2, ) - ) + ])) - # Residual blocks using anti-aliased multi-periodicity composition modules (AMP) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = self.h.upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate(zip(self.h.resblock_kernel_sizes, self.h.resblock_dilation_sizes)): + for k, d in zip(self.h.resblock_kernel_sizes, + self.h.resblock_dilation_sizes): self.resblocks.append( - resblock_class(self.h, ch, k, d, activation=self.h.activation, causal=self.causal) - ) + resblock_class(self.h, ch, k, d, + activation=self.h.activation, + causal=self.causal)) - # Post-conv activation_post = ( - SnakeBeta(ch, alpha_logscale=self.h.snake_logscale) if self.h.activation == "snakebeta" else None - ) + SnakeBeta(ch, alpha_logscale=self.h.snake_logscale) + if self.h.activation == "snakebeta" else None) if activation_post is None: raise NotImplementedError( - "activation incorrectly specified. check the config file and look for 'activation'." - ) + "activation incorrectly specified. check the config file " + "and look for 'activation'.") - self.activation_post = Activation1d(activation=activation_post, causal=False) - - # Whether to use bias for the final conv_post. Default to True for backward compatibility + self.activation_post = Activation1d(activation=activation_post, + causal=False) self.use_bias_at_final = self.h.get("use_bias_at_final", True) - self.conv_post = conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final) + self.conv_post = conv1d(ch, 1, 7, 1, padding=3, + bias=self.use_bias_at_final) - # Weight initialization for i in range(len(self.ups)): self.ups[i].apply(init_weights) self.conv_post.apply(init_weights) - # Final tanh activation. Defaults to True for backward compatibility self.use_tanh_at_final = self.h.get("use_tanh_at_final", True) - - self.num_layers = self.num_upsamples + self.num_upsamples * self.num_kernels + 3 + self.num_layers = (self.num_upsamples + self.num_upsamples + * self.num_kernels + 3) def forward_with_spk_emb(self, x, spk_or_ref, hidden_states=None): spk_emb = self.spk_emb(spk_or_ref) @@ -488,15 +364,12 @@ def forward(self, x, spk_emb, hidden_states=None): if hidden_states is None: hidden_states = [None] * self.num_layers else: - assert len(hidden_states) == self.num_layers, ( - f"Expected hidden_states to have {self.num_layers} elements, but got {len(hidden_states)}." - ) + assert len(hidden_states) == self.num_layers hidden_state_iter = iter(hidden_states) hidden_states_new = [] - # Unit and speaker embedding - x = self.unit_emb(x).transpose(1, 2) * (self.unit_emb_dim**-0.5) + x = self.unit_emb(x).transpose(1, 2) * (self.unit_emb_dim ** -0.5) if self.finetune: spk_emb = spk_emb.transpose(1, 2).expand(-1, -1, x.shape[-1]) else: @@ -507,33 +380,32 @@ def forward(self, x, spk_emb, hidden_states=None): hidden_states_new.append(h) for i, up_layers in enumerate(self.ups): - # Upsampling for up_layer in up_layers: x, h = up_layer(x, next(hidden_state_iter)) hidden_states_new.append(h) - # AMP blocks resblock_outputs = [ - self.resblocks[i * self.num_kernels + j](x, next(hidden_state_iter)) for j in range(self.num_kernels) + self.resblocks[i * self.num_kernels + j]( + x, next(hidden_state_iter)) + for j in range(self.num_kernels) ] x = sum(o for o, _ in resblock_outputs) / self.num_kernels hidden_states_new.extend([h for _, h in resblock_outputs]) - # Post-conv x, h = self.activation_post(x, next(hidden_state_iter)) hidden_states_new.append(h) x, h = self.conv_post(x, next(hidden_state_iter)) hidden_states_new.append(h) - # Final tanh activation + if self.use_tanh_at_final: x = torch.tanh(x) else: - x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1] + x = torch.clamp(x, min=-1.0, max=1.0) return x, hidden_states_new def remove_weight_norm(self): try: - print("Removing weight norm...") + logger.info("Removing weight norm...") for layer in self.ups: for l_i in layer: remove_weight_norm(l_i) @@ -542,39 +414,39 @@ def remove_weight_norm(self): remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) except ValueError: - print("[INFO] Model already removed weight norm. Skipping!") - pass + logger.info("Model already removed weight norm. Skipping!") @classmethod - def from_pretrained( - cls, - ckpt_path: str, - config_path: str | None = None, - map_location: str = "cpu", # Additional argument - ): - """Load Pytorch pretrained weights and return the loaded model.""" + def from_pretrained(cls, ckpt_path: str, config_path: str | None = None, + map_location: str = "cpu"): + """Load pretrained weights and return the model. - # Load hyperparameters (h) used by BigVGAN + BUG FIX: Original PR #869 passed AttrDict to cls(h) but __init__ + expects OmniDiffusionConfig. Fixed to unpack AttrDict as kwargs. + """ if config_path is None: - print("Loading config.json from local directory") + logger.info("Loading config.json from local directory") config_path = Path(ckpt_path).with_name("config.json") h = load_hparams_from_json(config_path) - # instantiate BigVGAN using h - model = cls(h) + # Keep only kwargs supported by __init__; MAR configs contain + # training-only fields (e.g., num_gpus) that are irrelevant here. + init_sig = inspect.signature(cls.__init__) + valid_params = set(init_sig.parameters.keys()) - {"self", "od_config"} + init_kwargs = {k: v for k, v in dict(h).items() if k in valid_params} + + model = cls(od_config=None, **init_kwargs) - # Load pretrained generator weight - print("Loading weights from local directory") + logger.info("Loading weights from local directory") checkpoint_dict = torch.load(ckpt_path, map_location=map_location) try: model.load_state_dict(checkpoint_dict["generator"]) except RuntimeError: - print( - "[INFO] the pretrained checkpoint does not contain weight norm. " - "Loading the checkpoint after removing weight norm!" - ) + logger.info( + "Pretrained checkpoint does not contain weight norm. " + "Loading after removing weight norm.") model.remove_weight_norm() model.load_state_dict(checkpoint_dict["generator"]) diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py b/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py index 91ea577a81c..0562db769e6 100644 --- a/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py +++ b/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py @@ -1,8 +1,12 @@ import base64 import io +import json import math import os +import tempfile +import zipfile from collections.abc import Iterable +from pathlib import Path from typing import Any import librosa @@ -18,10 +22,18 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device -from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.model_loader.diffusers_loader import ( + DiffusersPipelineLoader, +) from vllm_omni.diffusion.request import OmniDiffusionRequest -from .constants import AUDIO_FORMAT_MAP, DEFAULT_FORMAT, FORMAT_MIME_MAP, SPEAKERS_LIST, VOLUME_LEVEL +from .constants import ( + AUDIO_FORMAT_MAP, + DEFAULT_FORMAT, + FORMAT_MIME_MAP, + SPEAKERS_LIST, + VOLUME_LEVEL, +) from .hyperclovax_audio_decoder import HyperCLOVAXAudioDecoderModel logger = init_logger(__name__) @@ -32,13 +44,11 @@ def get_hyperclovax_audio_post_process_func(od_config: OmniDiffusionConfig): - """ - Get post-processing function for HyperCLOVAX Audio pipeline. + """Get post-processing function for HyperCLOVAX Audio pipeline.""" - Returns a function that converts model output tensors to audio file. - """ - - def post_process_func(output: list[tuple[torch.Tensor, str]]) -> list[bytes]: + def post_process_func( + output: list[tuple[torch.Tensor, str]], + ) -> list[bytes]: response = [] for wav_tensor, fmt in output: wav = wav_tensor.squeeze().cpu().numpy() @@ -48,36 +58,62 @@ def post_process_func(output: list[tuple[torch.Tensor, str]]) -> list[bytes]: response.append(pcm.tobytes()) continue - segment = AudioSegment(pcm.tobytes(), frame_rate=24000, sample_width=pcm.dtype.itemsize, channels=1) - + segment = AudioSegment( + pcm.tobytes(), frame_rate=24000, + sample_width=pcm.dtype.itemsize, channels=1) buf = io.BytesIO() - export_kwargs = {"format": fmt} - segment.export(buf, **export_kwargs) + segment.export(buf, format=fmt if fmt is not None else "wav") response.append(buf.getvalue()) + # BUG FIX #1: Original PR #869 was missing this return statement + return response + return post_process_func class HyperCLOVAXAudioPipeline(nn.Module): - def __init__( - self, - *, - od_config: OmniDiffusionConfig, - prefix: str = "", - ): + support_audio_output: bool = True + + @staticmethod + def get_dummy_extra() -> dict: + """Return dummy extra dict for warmup dummy run.""" + # Minimal dummy: one empty token sequence, default speaker + return {"audio_tokens": [[0] * 10]} + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): super().__init__() self.od_config = od_config self.device = get_local_device() self._dtype = od_config.dtype self.model = self.od_config.model + self._using_mar_checkpoint = False + self._mar_extract_dir: str | None = None + + # Default path: diffusers-style weights in bigvgan/ subfolder. self.weights_sources = [ DiffusersPipelineLoader.ComponentSource( - model_or_path=od_config.model, subfolder="bigvgan", revision=None, prefix=None, fall_back_to_pt=True + model_or_path=od_config.model, + subfolder="bigvgan", + revision=None, + prefix="bigvgan.", + fall_back_to_pt=True, ) ] - self.bigvgan = HyperCLOVAXAudioDecoderModel(od_config=od_config).to(self.device) + mar_path = self._resolve_mar_path(self.model) + if mar_path is not None: + self._using_mar_checkpoint = True + self.weights_sources = [] + ckpt_path, config_path = self._extract_mar_checkpoint(mar_path) + self.bigvgan = HyperCLOVAXAudioDecoderModel.from_pretrained( + ckpt_path=ckpt_path, + config_path=config_path, + map_location="cpu", + ).to(self.device) + else: + self.bigvgan = HyperCLOVAXAudioDecoderModel( + od_config=od_config).to(self.device) self.spk_emb = self.bigvgan.spk_emb.to(self.device) self._vocab = int(getattr(self.bigvgan.h, "num_units", 0)) @@ -85,128 +121,167 @@ def __init__( speakers = SPEAKERS_LIST self.speaker_map = {spk: i for i, spk in enumerate(speakers)} + def _resolve_mar_path(self, model: str | None) -> Path | None: + if model is None: + return None + + model_path = Path(model) + if model_path.is_file() and model_path.suffix == ".mar": + return model_path + + if not model_path.is_dir(): + return None + + candidates = [ + model_path / "NCCosybigvganDecoder.mar", + model_path / "NCZSCosybigvganDecoder.mar", + model_path / "decoder" / "audio" / "NCCosybigvganDecoder.mar", + model_path / "decoder" / "audio" / "NCZSCosybigvganDecoder.mar", + ] + for candidate in candidates: + if candidate.exists(): + return candidate + return None + + def _extract_mar_checkpoint(self, mar_path: Path) -> tuple[str, str]: + extract_dir = Path(tempfile.mkdtemp(prefix="hcx_audio_decoder_")) + self._mar_extract_dir = str(extract_dir) + + with zipfile.ZipFile(mar_path) as zf: + manifest = json.loads(zf.read("MAR-INF/MANIFEST.json")) + serialized_file = manifest.get("model", {}).get("serializedFile") + if not serialized_file: + raise ValueError(f"serializedFile not found in {mar_path}") + + zf.extract(serialized_file, path=extract_dir) + zf.extract("config.json", path=extract_dir) + + return str(extract_dir / serialized_file), str(extract_dir / "config.json") + def _prepare_batch( - self, audio_tokens: list[list[int]], speakers: list[str], formats: list[str], ref_audio_tokens: list[str] - ) -> list[tuple[torch.Tensor, torch.Tensor, str]]: - """ - Construct batch to forward through the model. - - Args: - - audio_tokens: List[List[int]]: discrete audio tokens to decode. - - speakers: List[str]: speaker IDs for output audio. - - formats: List[str]: output audio formats. - - ref_audio_tokens: List[str]: - List of base64 encoded reference audio. - If provided, speaker and format will be ignored. - - Returns: - batch: List of tuples of (audio_tokens, speaker_id or ref_mel, format) - """ + self, + audio_tokens: list[list[int]], + speakers: list[str], + formats: list[str], + ref_audio_tokens: list[str | None], + ) -> list[tuple[torch.Tensor, torch.Tensor, str | None]]: batch = [] - for units, speaker, fmt, ref_audio in zip(audio_tokens, speakers, formats, ref_audio_tokens): + for units, speaker, fmt, ref_audio in zip( + audio_tokens, speakers, formats, ref_audio_tokens): units = torch.tensor(units, dtype=torch.long, device=self.device) if self._vocab > 0: mask = (units < 0) | (units >= self._vocab) if mask.any(): bad_idxs = units[mask].tolist() - raise ValueError(f"Unit indices out of range [0-{self._vocab - 1}]: {bad_idxs}") + raise ValueError( + f"Unit indices out of range " + f"[0-{self._vocab - 1}]: {bad_idxs}") - if ref_audio is not None: - ref_audio_bytes = base64.b64decode(ref_audio.encode("ascii"), validate=True) + if ref_audio is not None and not self.bigvgan.finetune: + ref_audio_bytes = base64.b64decode( + ref_audio.encode("ascii"), validate=True) ref_mel = ( - self._get_reference_mel_spectrogram(ref_audio_bytes, self.bigvgan.h).to(self.device).to(self._dtype) - ) + self._get_reference_mel_spectrogram( + ref_audio_bytes, self.bigvgan.h) + .to(self.device).to(self._dtype)) + batch.append((units, ref_mel, None)) + elif ref_audio is None and not self.bigvgan.finetune: + # Zero-shot decoder (ECAPA-TDNN) with no reference: use zero + # mel as fallback so text-only requests don't crash. + n_mels = int(getattr(self.bigvgan.h, "num_mels", 100)) + ref_mel = torch.zeros(1, n_mels, 64, + device=self.device, dtype=self._dtype) batch.append((units, ref_mel, None)) else: speaker = "fkms" if speaker is None else speaker - fmt = DEFAULT_FORMAT.lower() if fmt is None else fmt.lower() + fmt = (DEFAULT_FORMAT.lower() if fmt is None + else fmt.lower()) if fmt not in FORMAT_MIME_MAP: - raise ValueError(f"Unsupported format '{fmt}'. Choose from {list(FORMAT_MIME_MAP)}") - speaker_id = torch.tensor([self.speaker_map[speaker]], dtype=torch.long) + raise ValueError( + f"Unsupported format '{fmt}'. " + f"Choose from {list(FORMAT_MIME_MAP)}") + speaker_id = torch.tensor( + [self.speaker_map[speaker]], dtype=torch.long) speaker_id = speaker_id.unsqueeze(0).to(self.device) - batch.append((units, speaker_id, fmt)) return batch def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: - """ - Generate audio from audio tokens. - - Args: - req: OmniDiffusionRequest must containing: - - extra["audio_tokens"]: List[List[int]]: [B, L] or [L, ] audio token ids. - - extra["speakers"]: List[str]: speaker for each audio sample. - - extra["formats"]: List[str]: output audio format for each audio sample. - - extra["ref_audio_tokens"]: List[str]: base64 encoded reference audio for each audio sample. - - Returns: - OmniDiffusionResponse: The diffusion response. - """ - - # 1. Validate inputs exist in request audio_tokens = req.extra.get("audio_tokens") if audio_tokens is None: - return DiffusionOutput(output=None, error="audio_tokens required in req.extra") + return DiffusionOutput( + output=None, error="audio_tokens required in req.extra") - speakers = req.extra.get("speakers") - if speakers is None: - return DiffusionOutput(output=None, error="speakers required in req.extra") + # Default speakers to "fkms" for each sample when not provided + # (e.g., when called from the pipeline stage processor). + speakers = req.extra.get( + "speakers", ["fkms"] * len(audio_tokens)) if len(audio_tokens) != len(speakers): - return DiffusionOutput(output=None, error="length of speakers and audio_tokens must be the same") + return DiffusionOutput( + output=None, + error="length of speakers and audio_tokens must be the same") - # Optional: audio format. If not provided, use wav format as default. - formats = req.extra.get("formats", [DEFAULT_FORMAT.lower()] * len(audio_tokens)) + formats = req.extra.get( + "formats", [DEFAULT_FORMAT.lower()] * len(audio_tokens)) if len(audio_tokens) != len(formats): - return DiffusionOutput(output=None, error="length of formats and audio_tokens must be the same") - - ref_audio_tokens = req.extra.get("ref_audio_tokens") + return DiffusionOutput( + output=None, + error="length of formats and audio_tokens must be the same") + + # BUG FIX #2: Original PR #869 didn't handle None ref_audio_tokens, + # causing len(None) TypeError + ref_audio_tokens = req.extra.get( + "ref_audio_tokens", [None] * len(audio_tokens)) if len(audio_tokens) != len(ref_audio_tokens): - return DiffusionOutput(output=None, error="length of ref_audio_tokens and audio_tokens must be the same") + return DiffusionOutput( + output=None, + error="length of ref_audio_tokens and audio_tokens " + "must be the same") - # 2. Construct batch from given request inputs - batch = self._prepare_batch(audio_tokens, speakers, formats, ref_audio_tokens) + batch = self._prepare_batch( + audio_tokens, speakers, formats, ref_audio_tokens) results: list[tuple[torch.Tensor, str]] = [] - for units, speaker, fmt in batch: - # 3. Convert to tensor if needed + for units, speaker_or_mel, fmt in batch: if isinstance(units, list): units = torch.tensor(units, dtype=torch.long) elif isinstance(units, np.ndarray): units = torch.from_numpy(units).long() if len(units.size()) == 2 and units.size(0) == 1: - return DiffusionOutput(output=None, error="the underlying decoder does not support batch inference yet") + return DiffusionOutput( + output=None, + error="the underlying decoder does not support " + "batch inference yet") - units = units.unsqueeze(0) - units = units.to(self.device) + units = units.unsqueeze(0).to(self.device) padded_unit, original_portion = self.pad(units) - # 4. Generate speaker embedding - spk_emb = self.spk_emb(speaker) + if fmt is None: + # ref_audio path: speaker_or_mel is a mel spectrogram + # (float tensor). Only works with ECAPA_TDNN (finetune=False). + if self.bigvgan.finetune: + return DiffusionOutput( + output=None, + error="Reference audio requires finetune=False " + "(ECAPA_TDNN speaker encoder)") + spk_emb = self.spk_emb(speaker_or_mel) + else: + # speaker_id path: speaker_or_mel is a LongTensor + spk_emb = self.spk_emb(speaker_or_mel) - # 5. Decode audio padded_out, hidden = self.bigvgan(padded_unit, spk_emb=spk_emb) del hidden out = self.unpad(padded_out, original_portion) - # 6. Append decoded audio to result results.append((out.to(torch.float32), fmt)) - return DiffusionOutput( - output=results, post_process_func=get_hyperclovax_audio_post_process_func(self.od_config) - ) + return DiffusionOutput(output=results) def pad(self, unit: torch.Tensor) -> tuple[torch.Tensor, float]: - """ - Pad the `unit` tensor to AUDIOLLM_PAD_MULTIPLE environment variable. - - Args: - unit: int tensor of shape [1, L] - """ - pad_multiple = self._get_pad_multiple() if not pad_multiple: return unit, 1.0 @@ -215,183 +290,146 @@ def pad(self, unit: torch.Tensor) -> tuple[torch.Tensor, float]: if pad_token_id is None: return unit, 1.0 + # BUG FIX #4: Original PR #869 always padded, even when already aligned. + # When overflow==0, pad_amount was pad_multiple instead of 0. overflow = unit.shape[1] % pad_multiple + if overflow == 0: + return unit, 1.0 pad_amount = pad_multiple - overflow - padded = torch.nn.functional.pad(unit, (0, pad_amount), mode="constant", value=pad_token_id) + padded = torch.nn.functional.pad( + unit, (0, pad_amount), mode="constant", value=pad_token_id) return padded, unit.shape[-1] / padded.shape[-1] def unpad(self, x: torch.Tensor, original_portion: float) -> torch.Tensor: - """ - Unpad the `x` tensor by retaining only the `original_portion`. - - Args: - x: tensor of shape [..., T] - original_portion: ratio of original unit length over padded unit length - """ - return x[..., : math.ceil(x.shape[-1] * original_portion)] - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - """ - Load model weights using AutoWeightsLoader. - """ + return x[..., :math.ceil(x.shape[-1] * original_portion)] + + def load_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> set[str]: + # MAR checkpoint path already loads bigvgan weights eagerly. + if self._using_mar_checkpoint: + # Weights already loaded in __init__ via MAR extraction. + # Return all parameter names to pass the strict loading check. + return {name for name, _ in self.named_parameters()} + loader = AutoWeightsLoader(self) return loader.load_weights(weights) def _get_pad_multiple(self) -> int | None: - pad_multiple_str = os.getenv("AUDIOLLM_PAD_MULTIPLE", 100) + pad_multiple_str = os.getenv("AUDIOLLM_PAD_MULTIPLE", "100") if not pad_multiple_str: return None - try: pad_multiple = int(pad_multiple_str) except ValueError: - logger.warning("AUDIOLLM_PAD_MULTIPLE environment variable is not a valid int. Skipping padding...") + logger.warning("AUDIOLLM_PAD_MULTIPLE is not a valid int.") return None - if pad_multiple <= 0: - logger.warning("AUDIOLLM_PAD_MULTIPLE environment variable is not a positive int. Skipping padding...") return None - return pad_multiple def _get_pad_token_id(self) -> int | None: - pad_token_id_str = os.getenv("AUDIOLLM_PAD_TOKEN_ID", 3894) + pad_token_id_str = os.getenv("AUDIOLLM_PAD_TOKEN_ID", "3894") if not pad_token_id_str: - logger.warning("AUDIOLLM_PAD_TOKEN_ID environment variable is not set. Skipping padding...") return None - try: pad_token_id = int(pad_token_id_str) except ValueError: - logger.warning("AUDIOLLM_PAD_TOKEN_ID environment variable is not a valid int. Skipping padding...") + logger.warning("AUDIOLLM_PAD_TOKEN_ID is not a valid int.") return None - if pad_token_id < 0: - logger.warning("AUDIOLLM_PAD_TOKEN_ID environment variable is a negative int. Skipping padding...") return None - return pad_token_id def _get_down_sample_rate(self) -> float | None: down_sample_rate_str = os.getenv("AUDIOLLM_DOWN_SAMPLE_RATE") if not down_sample_rate_str: return None - try: down_sample_rate = float(down_sample_rate_str) except ValueError: - logger.warning( - "AUDIOLLM_DOWN_SAMPLE_RATE environment variable is not a valid float. Skipping down-sampling..." - ) return None - if down_sample_rate <= 0: - logger.warning( - "AUDIOLLM_DOWN_SAMPLE_RATE environment variable is not a positive float. Skipping down-sampling..." - ) return None - return down_sample_rate def _detect_audio_format(self, header_bytes: bytes) -> str | None: - """ - Detect audio format from header bytes of audio file. - - Args: - header_bytes: first 4 bytes of audio file. - """ for prefix_bytes, fmt in AUDIO_FORMAT_MAP: if header_bytes.startswith(prefix_bytes): return fmt return None - def _hpf_normalize(self, pcm: np.ndarray, sr: int | float, volume_level: float) -> np.ndarray: - assert (pcm**2).mean() > 0, "Error in the wav file" + def _hpf_normalize(self, pcm: np.ndarray, sr: int | float, + volume_level: float) -> np.ndarray: + assert (pcm ** 2).mean() > 0, "Error in the wav file" assert np.issubdtype(pcm.dtype, np.floating) - - # highpass filter filter_ = scipy.signal.butter(2, 70, "highpass", fs=sr, output="sos") pcm = scipy.signal.sosfilt(filter_, pcm) pcm = pcm.astype(np.float32) - - # volume normalize - gain = min(volume_level / (pcm**2).mean() ** 0.5, 1 / np.max(np.abs(pcm))) + gain = min(volume_level / (pcm ** 2).mean() ** 0.5, + 1 / np.max(np.abs(pcm))) pcm *= gain return pcm - def _load_reference_audio(self, audio: bytes, sample_rate: float) -> np.ndarray: - audio = io.BytesIO(audio) - fmt = self._detect_audio_format(audio[:4]) + def _load_reference_audio(self, audio: bytes, + sample_rate: float) -> np.ndarray: + # BUG FIX #3: Original PR #869 tried audio[:4] on BytesIO object. + # Must read header bytes BEFORE wrapping in BytesIO. + header = audio[:4] + audio_io = io.BytesIO(audio) + fmt = self._detect_audio_format(header) if fmt: - segment = pydub.AudioSegment.from_file(audio, format=fmt) + segment = pydub.AudioSegment.from_file(audio_io, format=fmt) else: - segment = pydub.AudioSegment.from_file(audio) + segment = pydub.AudioSegment.from_file(audio_io) wav_file = io.BytesIO() segment.export(wav_file, format="wav") wav_file.seek(0) - # Down-sample to reduce noise in final result. load_sr = self._get_down_sample_rate() if load_sr is None: load_sr = sample_rate pcm, sr = librosa.load(wav_file, sr=load_sr, mono=True) pcm = librosa.resample(pcm, orig_sr=sr, target_sr=sample_rate) - pcm = self._hpf_normalize(pcm, sample_rate, VOLUME_LEVEL) return pcm - def _compute_mel_spectrogram(self, y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + def _compute_mel_spectrogram( + self, y, n_fft, num_mels, sampling_rate, hop_size, win_size, + fmin, fmax, center=False, + ): global mel_basis, hann_window - # Create a unique key based on fmax and device key = f"{fmax}_{y.device}" if key not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, + n_mels=num_mels, fmin=fmin, fmax=fmax) mel_basis[key] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to( + y.device) - # Pad the signal for STFT pad_amount = int((n_fft - hop_size) / 2) - y = torch.nn.functional.pad(y.unsqueeze(1), (pad_amount, pad_amount), mode="reflect").squeeze(1) + y = torch.nn.functional.pad( + y.unsqueeze(1), (pad_amount, pad_amount), + mode="reflect").squeeze(1) - # Compute the Short-Time Fourier Transform (STFT) spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - - # Compute the magnitude spectrogram with a small epsilon to avoid log(0) + y, n_fft, hop_length=hop_size, win_length=win_size, + window=hann_window[str(y.device)], center=center, + pad_mode="reflect", normalized=False, onesided=True, + return_complex=True) spec = torch.sqrt(torch.real(spec * spec.conj() + 1e-9)) - - # Map the linear-frequency spectrogram to the mel scale spec = torch.matmul(mel_basis[key], spec) - - # Apply spectral normalization (dynamic range compression) spec = torch.log(torch.clamp(spec, min=1e-5)) - return spec - def _get_reference_mel_spectrogram(self, ref_audio: bytes, h: dict[str, Any]) -> torch.Tensor: + def _get_reference_mel_spectrogram( + self, ref_audio: bytes, h: dict[str, Any] + ) -> torch.Tensor: pcm = self._load_reference_audio(ref_audio, h.sampling_rate) pcm = torch.from_numpy(pcm).unsqueeze(0) - mel = self._compute_mel_spectrogram( - pcm, - h.n_fft, - h.num_mels, - h.sampling_rate, - h.hop_size, - h.win_size, - h.fmin, - h.fmax, - ) + pcm, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, + h.win_size, h.fmin, h.fmax) return mel diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 2fc2121fa9a..501f1e54934 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -3,9 +3,15 @@ import importlib +import torch.nn as nn +from vllm.logger import init_logger from vllm.model_executor.models.registry import _LazyRegisteredModel, _ModelRegistry from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig, get_sp_plan_from_model +from vllm_omni.diffusion.hooks.sequence_parallel import apply_sequence_parallel + +logger = init_logger(__name__) _DIFFUSION_MODELS = { # arch:(mod_folder, mod_relname, cls_name) @@ -29,6 +35,11 @@ "pipeline_qwen_image_layered", "QwenImageLayeredPipeline", ), + "GlmImagePipeline": ( + "glm_image", + "pipeline_glm_image", + "GlmImagePipeline", + ), "ZImagePipeline": ( "z_image", "pipeline_z_image", @@ -74,11 +85,36 @@ "pipeline_sd3", "StableDiffusion3Pipeline", ), + "HunyuanImage3ForCausalMM": ( + "hunyuan_image_3", + "pipeline_hunyuan_image_3", + "HunyuanImage3Pipeline", + ), "Flux2KleinPipeline": ( "flux2_klein", "pipeline_flux2_klein", "Flux2KleinPipeline", ), + "NextStep11Pipeline": ( + "nextstep_1_1", + "pipeline_nextstep_1_1", + "NextStep11Pipeline", + ), + "FluxPipeline": ( + "flux", + "pipeline_flux", + "FluxPipeline", + ), + "OmniGen2Pipeline": ( + "omnigen2", + "pipeline_omnigen2", + "OmniGen2Pipeline", + ), + "HyperCLOVAXVisionPipeline": ( + "hyperclovax_vision", + "pipeline_hyperclovax_vision", + "HyperCLOVAXVisionPipeline", + ), "HyperCLOVAXAudioPipeline": ( "hyperclovax_audio", "pipeline_hyperclovax_audio", @@ -97,24 +133,158 @@ } ) +_VAE_PATCH_PARALLEL_ALLOWLIST = { + # Only enable for models we have validated end-to-end. + "StableDiffusion3Pipeline", + "ZImagePipeline", + "NextStep11Pipeline", +} + +_NO_CACHE_ACCELERATION = { + # Pipelines that do not support cache acceleration (cache_dit / tea_cache). + "NextStep11Pipeline", +} + def initialize_model( od_config: OmniDiffusionConfig, -): +) -> nn.Module: + """Initialize a diffusion model from the registry. + + This function: + 1. Loads the model class from the registry + 2. Instantiates the model with the config + 3. Configures VAE optimization settings + 4. Applies sequence parallelism if enabled (similar to diffusers' enable_parallelism) + + Args: + od_config: The OmniDiffusion configuration. + + Returns: + The initialized pipeline model. + + Raises: + ValueError: If the model class is not found in the registry. + """ model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name) if model_class is not None: model = model_class(od_config=od_config) + + vae_pp_size = od_config.parallel_config.vae_patch_parallel_size + if vae_pp_size > 1 and od_config.model_class_name not in _VAE_PATCH_PARALLEL_ALLOWLIST: + logger.warning( + "vae_patch_parallel_size=%d is set but VAE patch parallelism is only enabled for %s; ignoring.", + vae_pp_size, + sorted(_VAE_PATCH_PARALLEL_ALLOWLIST), + ) + if ( + vae_pp_size > 1 + and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + and not od_config.vae_use_tiling + ): + logger.info( + "vae_patch_parallel_size=%d requires vae_use_tiling; automatically enabling it.", + vae_pp_size, + ) + od_config.vae_use_tiling = True + # Configure VAE memory optimization settings from config - if hasattr(model.vae, "use_slicing"): + if hasattr(model, "vae") and hasattr(model.vae, "use_slicing"): model.vae.use_slicing = od_config.vae_use_slicing - if hasattr(model.vae, "use_tiling"): + if hasattr(model, "vae") and hasattr(model.vae, "use_tiling"): model.vae.use_tiling = od_config.vae_use_tiling + if ( + vae_pp_size > 1 + and hasattr(model, "vae") + and od_config.model_class_name in _VAE_PATCH_PARALLEL_ALLOWLIST + and od_config.vae_use_tiling + ): + from vllm_omni.diffusion.distributed.parallel_state import get_dit_group + from vllm_omni.diffusion.distributed.vae_patch_parallel import maybe_wrap_vae_decode_with_patch_parallelism + + maybe_wrap_vae_decode_with_patch_parallelism( + model, + vae_patch_parallel_size=vae_pp_size, + group_getter=get_dit_group, + ) + + # Apply sequence parallelism if enabled + # This follows diffusers' pattern where enable_parallelism() is called + # at model loading time, not inside individual model files + _apply_sequence_parallel_if_enabled(model, od_config) + return model else: raise ValueError(f"Model class {od_config.model_class_name} not found in diffusion model registry.") +def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -> None: + """Apply sequence parallelism hooks if SP is enabled. + + This is the centralized location for enabling SP, similar to diffusers' + ModelMixin.enable_parallelism() method. It applies _sp_plan hooks to + transformer models that define them. + + Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers. + We use _sp_plan instead of diffusers' _cp_plan. + + Args: + model: The pipeline model (e.g., ZImagePipeline). + od_config: The OmniDiffusion configuration. + """ + + try: + sp_size = od_config.parallel_config.sequence_parallel_size + if sp_size <= 1: + return + + # Find transformer model(s) in the pipeline that have _sp_plan + # Include transformer_2 for two-stage models (e.g., Wan MoE) + transformer_attrs = ["transformer", "transformer_2", "dit", "unet"] + applied_count = 0 + + for attr in transformer_attrs: + if not hasattr(model, attr): + continue + + transformer = getattr(model, attr) + if transformer is None: + continue + + plan = get_sp_plan_from_model(transformer) + if plan is None: + continue + + # Create SP config + sp_config = SequenceParallelConfig( + ulysses_degree=od_config.parallel_config.ulysses_degree, + ring_degree=od_config.parallel_config.ring_degree, + ) + + # Apply hooks according to the plan + mode = ( + "hybrid" + if sp_config.ulysses_degree > 1 and sp_config.ring_degree > 1 + else ("ulysses" if sp_config.ulysses_degree > 1 else "ring") + ) + logger.info( + f"Applying sequence parallelism to {transformer.__class__.__name__} ({attr}) " + f"(sp_size={sp_size}, mode={mode}, ulysses={sp_config.ulysses_degree}, ring={sp_config.ring_degree})" + ) + apply_sequence_parallel(transformer, sp_config, plan) + applied_count += 1 + + if applied_count == 0: + logger.warning( + f"Sequence parallelism is enabled (sp_size={sp_size}) but no transformer with _sp_plan found. " + "SP hooks not applied. Consider adding _sp_plan to your transformer model." + ) + + except Exception as e: + logger.warning(f"Failed to apply sequence parallelism: {e}. Continuing without SP hooks.") + + _DIFFUSION_POST_PROCESS_FUNCS = { # arch: post_process_func # `post_process_func` function must be placed in {mod_folder}/{mod_relname}.py, @@ -122,6 +292,7 @@ def initialize_model( "QwenImagePipeline": "get_qwen_image_post_process_func", "QwenImageEditPipeline": "get_qwen_image_edit_post_process_func", "QwenImageEditPlusPipeline": "get_qwen_image_edit_plus_post_process_func", + "GlmImagePipeline": "get_glm_image_post_process_func", "ZImagePipeline": "get_post_process_func", "OvisImagePipeline": "get_ovis_image_post_process_func", "WanPipeline": "get_wan22_post_process_func", @@ -132,6 +303,10 @@ def initialize_model( "LongCatImageEditPipeline": "get_longcat_image_post_process_func", "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", "Flux2KleinPipeline": "get_flux2_klein_post_process_func", + "NextStep11Pipeline": "get_nextstep11_post_process_func", + "FluxPipeline": "get_flux_post_process_func", + "OmniGen2Pipeline": "get_omnigen2_post_process_func", + "HyperCLOVAXVisionPipeline": "get_hyperclovax_vision_post_process_func", "HyperCLOVAXAudioPipeline": "get_hyperclovax_audio_post_process_func", } @@ -139,12 +314,14 @@ def initialize_model( # arch: pre_process_func # `pre_process_func` function must be placed in {mod_folder}/{mod_relname}.py, # where mod_folder and mod_relname are defined and mapped using `_DIFFUSION_MODELS` via the `arch` key + "GlmImagePipeline": "get_glm_image_pre_process_func", "QwenImageEditPipeline": "get_qwen_image_edit_pre_process_func", "QwenImageEditPlusPipeline": "get_qwen_image_edit_plus_pre_process_func", "LongCatImageEditPipeline": "get_longcat_image_edit_pre_process_func", "QwenImageLayeredPipeline": "get_qwen_image_layered_pre_process_func", "WanPipeline": "get_wan22_pre_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", + "OmniGen2Pipeline": "get_omnigen2_pre_process_func", } diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py index cd31aba6737..94229fe86c3 100644 --- a/vllm_omni/diffusion/request.py +++ b/vllm_omni/diffusion/request.py @@ -2,12 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pprint -from dataclasses import asdict, dataclass, field -from typing import Any +import random +from dataclasses import dataclass, field -import PIL.Image -import torch +from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType @dataclass @@ -15,9 +13,8 @@ class OmniDiffusionRequest: """ Complete state passed through the pipeline execution. - This dataclass contains all information needed during the diffusion pipeline - execution, allowing methods to update specific components without needing - to manage numerous individual parameters. + This dataclass contains the prompts and sampling parameters for the diffusion pipeline + execution. It also contains a request_id for other components to trace this request and its outputs. """ # TODO(will): double check that args are separate from server_args @@ -25,164 +22,36 @@ class OmniDiffusionRequest: # specific arguments. # data_type: DataType - request_id: str | None = None + prompts: list[OmniPromptType] # Actually supporting str-based prompts + sampling_params: OmniDiffusionSamplingParams - generator: torch.Generator | list[torch.Generator] | None = None - - # Image inputs - image_path: str | None = None - # Image encoder hidden states - image_embeds: list[torch.Tensor] = field(default_factory=list) - pil_image: torch.Tensor | PIL.Image.Image | None = None - pixel_values: torch.Tensor | PIL.Image.Image | None = None - preprocessed_image: torch.Tensor | None = None - - # Text inputs - prompt: str | list[str] | None = None - negative_prompt: str | list[str] | None = None - prompt_path: str | None = None - output_path: str = "outputs/" - # without extension - output_file_name: str | None = None - output_file_ext: str | None = None - # Primary encoder embeddings - prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list) - negative_prompt_embeds: list[torch.Tensor] | None = None - prompt_attention_mask: list[torch.Tensor] | None = None - negative_attention_mask: list[torch.Tensor] | None = None - clip_embedding_pos: list[torch.Tensor] | None = None - clip_embedding_neg: list[torch.Tensor] | None = None - - pooled_embeds: list[torch.Tensor] = field(default_factory=list) - neg_pooled_embeds: list[torch.Tensor] = field(default_factory=list) - - # Additional text-related parameters - max_sequence_length: int | None = None - prompt_template: dict[str, Any] | None = None - do_classifier_free_guidance: bool = False - - # Batch info - num_outputs_per_prompt: int = 1 - seed: int | None = None - seeds: list[int] | None = None - - # layered info - layers: int = 4 - - # cfg info - cfg_normalize: bool = False - - # caption language - use_en_prompt: bool = False - - # different bucket in (640, 1024) to determine the condition and output resolution - resolution: int = 640 - - # Tracking if embeddings are already processed - is_prompt_processed: bool = False - - # Latent tensors - latents: torch.Tensor | None = None - raw_latent_shape: torch.Tensor | None = None - noise_pred: torch.Tensor | None = None - image_latent: torch.Tensor | None = None - - # Latent dimensions - height_latents: list[int] | int | None = None - width_latents: list[int] | int | None = None - num_frames: list[int] | int = 1 # Default for image models - num_frames_round_down: bool = False # Whether to round down num_frames if it's not divisible by num_gpus - - # Original dimensions (before VAE scaling) - height: list[int] | int | None = None - width: list[int] | int | None = None - fps: list[int] | int | None = None - height_not_provided: bool = False - width_not_provided: bool = False - - # Timesteps - timesteps: torch.Tensor | None = None - timestep: torch.Tensor | float | int | None = None - step_index: int | None = None - boundary_ratio: float | None = None - - # Scheduler parameters - num_inference_steps: int = 50 - guidance_scale: float = 1.0 - guidance_scale_provided: bool = False - guidance_scale_2: float | None = None - guidance_rescale: float = 0.0 - eta: float = 0.0 - sigmas: list[float] | None = None - - true_cfg_scale: float | None = None # qwen-image specific now - - n_tokens: int | None = None - - # Other parameters that may be needed by specific schedulers - extra_step_kwargs: dict[str, Any] = field(default_factory=dict) - - # Component modules (populated by the pipeline) - modules: dict[str, Any] = field(default_factory=dict) - - return_trajectory_latents: bool = False - return_trajectory_decoded: bool = False - trajectory_timesteps: list[torch.Tensor] | None = None - trajectory_latents: torch.Tensor | None = None - - # Extra parameters that might be needed by specific pipeline implementations - extra: dict[str, Any] = field(default_factory=dict) - - # Misc - save_output: bool = True - return_frames: bool = False - - # STA parameters - STA_param: list | None = None - is_cfg_negative: bool = False - mask_search_final_result_pos: list[list] | None = None - mask_search_final_result_neg: list[list] | None = None - - # VSA parameters - VSA_sparsity: float = 0.0 - # perf_logger: PerformanceLogger | None = None - - # stage logging - # logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo) - - # profile - profile: bool = False - num_profiled_timesteps: int = 8 - - # debugging - debug: bool = False - - # results - output: torch.Tensor | None = None - - @property - def batch_size(self): - # Determine batch size - if isinstance(self.prompt, list): - batch_size = len(self.prompt) - elif self.prompt is not None: - batch_size = 1 - else: - batch_size = self.prompt_embeds[0].shape[0] - - # Adjust batch size for number of videos per prompt - batch_size *= self.num_outputs_per_prompt - return batch_size + request_ids: list[str] = field(default_factory=list) + extra: dict = field(default_factory=dict) # Additional data from stage input processors (e.g. vision_tokens, audio_tokens) def __post_init__(self): """Initialize dependent fields after dataclass initialization.""" - # Set do_classifier_free_guidance based on guidance scale and negative prompt - if self.guidance_scale > 1.0 and self.negative_prompt is not None: - self.do_classifier_free_guidance = True - if self.negative_prompt_embeds is None: - self.negative_prompt_embeds = [] - if self.guidance_scale_2 is None: - self.guidance_scale_2 = self.guidance_scale + # When neither a generator nor a seed is provided, assign a random seed + # so that all ranks derive the same generator state. + if self.sampling_params.generator is None and self.sampling_params.seed is None: + self.sampling_params.seed = random.randint(0, 2**31 - 1) + + # Detect whether user explicitly provided guidance_scale. + # The sentinel default is 0.0 (false-like); any truthy value means + # the caller set it intentionally. We must resolve this BEFORE + # auto-filling guidance_scale_2, otherwise the sentinel leaks into + # guidance_scale_2. + if self.sampling_params.guidance_scale: + self.sampling_params.guidance_scale_provided = True + else: + self.sampling_params.guidance_scale = 1.0 - def __str__(self): - return pprint.pformat(asdict(self), indent=2, width=120) + # Set do_classifier_free_guidance based on guidance scale and negative prompt + if self.sampling_params.guidance_scale > 1.0 and any( + (not isinstance(p, str) and p.get("negative_prompt")) for p in self.prompts + ): + self.sampling_params.do_classifier_free_guidance = True + + # Auto-fill guidance_scale_2 from the (now-resolved) guidance_scale + # so downstream code always has a valid value. + if self.sampling_params.guidance_scale_2 is None: + self.sampling_params.guidance_scale_2 = self.sampling_params.guidance_scale diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py new file mode 100644 index 00000000000..f589f370661 --- /dev/null +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -0,0 +1,751 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Diffusion Worker for vLLM-Omni. + +Handles GPU infrastructure initialization and delegates model operations +to DiffusionModelRunner. +""" + +import gc +import multiprocessing as mp +import os +from collections.abc import Iterable +from contextlib import AbstractContextManager, nullcontext +from typing import Any + +import torch +import zmq +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.logger import init_logger +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.mem_utils import GiB_bytes +from vllm.v1.worker.workspace import init_workspace_manager + +from vllm_omni.diffusion.data import ( + DiffusionOutput, + OmniDiffusionConfig, +) +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, +) +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.ipc import pack_diffusion_output_shm +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput +from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner +from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.lora.request import LoRARequest +from vllm_omni.platforms import current_omni_platform +from vllm_omni.profiler import OmniTorchProfilerWrapper, create_omni_profiler +from vllm_omni.worker.gpu_memory_utils import get_process_gpu_memory + +logger = init_logger(__name__) + + +class DiffusionWorker: + """ + A worker that manages GPU infrastructure and delegates to the model runner. + + This class handles infrastructure initialization only: + - Device setup (CUDA device selection) + - Distributed environment (NCCL, model parallel) + - Memory management (sleep/wake) + + All model-related operations (loading, compilation, execution) are + delegated to DiffusionModelRunner. + """ + + def __init__( + self, + local_rank: int, + rank: int, + od_config: OmniDiffusionConfig, + skip_load_model: bool = False, + ): + self.local_rank = local_rank + self.rank = rank + self.od_config = od_config + self.device: torch.device | None = None + self.vllm_config: VllmConfig | None = None + self.model_runner: DiffusionModelRunner | None = None + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + self.lora_manager: DiffusionLoRAManager | None = None + self.init_device() + # Create model runner + self.model_runner = DiffusionModelRunner( + vllm_config=self.vllm_config, + od_config=self.od_config, + device=self.device, + ) + # Initialize profiler if configured + self.profiler: OmniTorchProfilerWrapper | None = None + profiler_config = self.od_config.profiler_config + if profiler_config and profiler_config.profiler == "torch": + self.profiler = create_omni_profiler( + profiler_config=profiler_config, + worker_name=f"diffusion_worker_{self.rank}", + local_rank=self.local_rank, + ) + if not skip_load_model: + self.load_model(load_format=self.od_config.diffusion_load_format) + self.init_lora_manager() + logger.info(f"Worker {self.rank}: Initialization complete.") + + def init_device(self) -> None: + """Initialize the device and distributed environment.""" + import logging as _l + _l.getLogger(__name__).warning('[DEBUG-DEVICE] CUDA_VISIBLE_DEVICES=%s, rank=%s', os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT_SET'), getattr(self, 'rank', 'N/A')) + world_size = self.od_config.num_gpus + rank = self.rank + + # Set environment variables for distributed initialization + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.od_config.master_port) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Setup device + self.device = current_omni_platform.get_torch_device(rank) + current_omni_platform.set_device(self.device) + + # Create vllm_config for parallel configuration + vllm_config = VllmConfig(compilation_config=CompilationConfig()) + vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size + vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size + vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel + self.vllm_config = vllm_config + + # Initialize distributed environment + with ( + set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config), + set_current_vllm_config(self.vllm_config), + ): + init_distributed_environment(world_size=world_size, rank=rank) + logger.info(f"Worker {self.rank}: Initialized device and distributed environment.") + + parallel_config = self.od_config.parallel_config + initialize_model_parallel( + data_parallel_size=parallel_config.data_parallel_size, + cfg_parallel_size=parallel_config.cfg_parallel_size, + sequence_parallel_size=parallel_config.sequence_parallel_size, + ulysses_degree=parallel_config.ulysses_degree, + ring_degree=parallel_config.ring_degree, + tensor_parallel_size=parallel_config.tensor_parallel_size, + pipeline_parallel_size=parallel_config.pipeline_parallel_size, + fully_shard_degree=parallel_config.hsdp_shard_size if parallel_config.use_hsdp else 1, + hsdp_replicate_size=parallel_config.hsdp_replicate_size if parallel_config.use_hsdp else 1, + enable_expert_parallel=parallel_config.enable_expert_parallel, + ) + init_workspace_manager(self.device) + + def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None: + """Load the diffusion model using DiffusionModelRunner.""" + with ( + set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config), + set_current_vllm_config(self.vllm_config), + ): + self.model_runner.load_model( + memory_pool_context_fn=self._maybe_get_memory_pool_context, + load_format=load_format, + custom_pipeline_name=custom_pipeline_name, + ) + process_memory = get_process_gpu_memory(self.local_rank) + if process_memory is not None: + logger.info( + "Worker %d: Process-scoped GPU memory after model loading: %.2f GiB.", + self.rank, + process_memory / GiB_bytes, + ) + + # When load_format is "dummy", pipeline will init with custom pipeline later + if load_format != "dummy": + assert self.model_runner.pipeline is not None + + def init_lora_manager(self) -> None: + """Initialize the LoRA manager for this worker.""" + if self.model_runner.pipeline is None: + return + self.lora_manager = DiffusionLoRAManager( + pipeline=self.model_runner.pipeline, + device=self.device, + dtype=self.od_config.dtype, + max_cached_adapters=self.od_config.max_cpu_loras, + lora_path=self.od_config.lora_path, + lora_scale=self.od_config.lora_scale, + ) + + def generate(self, request: OmniDiffusionRequest) -> DiffusionOutput: + """Generate output for the given requests.""" + return self.execute_model(request, self.od_config) + + def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: + """Start or stop profiling for this GPU worker. + + Args: + is_start: True to start profiling, False to stop. + profile_prefix: Optional prefix for trace filename (vLLM compat). + + Note: + Matches vLLM's worker.profile() signature for consistency. + Traces are saved automatically via on_trace_ready callback. + """ + if self.profiler is None: + logger.warning("Profiler not initialized, skipping profile(%s)", is_start) + return + + if is_start: + from vllm_omni.profiler import OmniTorchProfilerWrapper + + if isinstance(self.profiler, OmniTorchProfilerWrapper): + import time + + filename = profile_prefix or f"diffusion_{int(time.time())}" + self.profiler.set_trace_filename(filename) + self.profiler.start() + else: + self.profiler.stop() + + def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput: + """Execute a forward pass by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None: + try: + self.lora_manager.set_active_adapter(req.sampling_params.lora_request, req.sampling_params.lora_scale) + except Exception as exc: + if req.sampling_params.lora_request is not None: + raise + logger.warning("LoRA activation skipped: %s", exc) + return self.model_runner.execute_model(req) + + def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one diffusion step by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None: + # Step mode does not support LoRA yet. Clear any previously active + # adapter first so worker-local LoRA state cannot leak in. + self.lora_manager.set_active_adapter(None) + + if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): + raise ValueError("Step mode does not support LoRA yet.") + + return self.model_runner.execute_stepwise(scheduler_output) + + def load_weights(self, weights) -> set[str]: + """Load weights by delegating to the model runner.""" + assert self.model_runner is not None, "Model runner not initialized" + return self.model_runner.load_weights(weights) + + def remove_lora(self, adapter_id: int) -> bool: + return self.lora_manager.remove_adapter(adapter_id) + + def add_lora(self, lora_request: LoRARequest) -> bool: + # NOTE (Alex): We have not implemented the API routing + # for the frontend server yet. + return self.lora_manager.add_adapter(lora_request) + + def list_loras(self) -> list[int]: + return self.lora_manager.list_adapters() + + def pin_lora(self, adapter_id: int) -> bool: + return self.lora_manager.pin_adapter(adapter_id) + + def sleep(self, level: int = 1) -> bool: + """ + Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + """ + from vllm.device_allocator.cumem import CuMemAllocator + + process_memory_before_sleep = get_process_gpu_memory(self.local_rank) + free_bytes_before_sleep = None + if process_memory_before_sleep is None: + free_bytes_before_sleep = current_omni_platform.get_free_memory() + + # Save the buffers before level 2 sleep + if level == 2 and self.model_runner is not None: + model = self.model_runner.pipeline + self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} + + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) + process_memory_after_sleep = get_process_gpu_memory(self.local_rank) + if process_memory_before_sleep is not None and process_memory_after_sleep is not None: + freed_bytes = process_memory_before_sleep - process_memory_after_sleep + used_bytes = process_memory_after_sleep + accounting_scope = "process-scoped" + else: + free_bytes_after_sleep = current_omni_platform.get_free_memory() + assert free_bytes_before_sleep is not None + device_id = self.device.index if self.device.index is not None else 0 + total = current_omni_platform.get_device_total_memory(device_id) + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + accounting_scope = "device-scoped fallback" + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode (%s) freed %.2f GiB memory, %.2f GiB memory is still in use.", + accounting_scope, + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) + return True + + def wake_up(self, tags: list[str] | None = None) -> bool: + """ + Wake up the worker from sleep mode. See the sleep function + method for more details. + + Args: + tags: An optional list of tags to reallocate the worker memory + for specific memory allocations. Values must be in + `("weights")`. If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + worker is used again. + """ + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers) and self.model_runner is not None: + model = self.model_runner.pipeline + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + return True + + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + """Get memory pool context for sleep mode support.""" + if self.od_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." + return allocator.use_memory_pool(tag=tag) + else: + return nullcontext() + + def shutdown(self) -> None: + """Shutdown the worker and cleanup distributed environment.""" + destroy_distributed_env() + + +class CustomPipelineWorkerExtension: + def re_init_pipeline(self, custom_pipeline_args: dict[str, Any]) -> None: + """ + Re-initialize the pipeline with custom arguments. + + Args: + custom_pipeline_args: Dictionary of arguments for custom pipeline initialization + """ + + # Clean up old pipeline + if self.model_runner.pipeline is not None: + del self.model_runner.pipeline + gc.collect() + torch.cuda.empty_cache() + + # Get custom pipeline class name + custom_pipeline_name = custom_pipeline_args["pipeline_class"] + + # Use the DiffusionWorker's load_model method which handles the forward context + self.load_model( + load_format="custom_pipeline", + custom_pipeline_name=custom_pipeline_name, + ) + self.init_lora_manager() + + +class WorkerProc: + """Wrapper that runs one Worker in a separate process.""" + + def __init__( + self, + od_config: OmniDiffusionConfig, + gpu_id: int, + broadcast_handle, + worker_extension_cls: str | None = None, + custom_pipeline_args: dict[str, Any] | None = None, + ): + self.od_config = od_config + self.gpu_id = gpu_id + + # Inter-process Communication + self.context = zmq.Context(io_threads=2) + + # Initialize MessageQueue reader from handle + self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id) + + self.result_mq = None + self.result_mq_handle = None + + # Setup result sender (only for rank 0) + if gpu_id == 0: + self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0]) + self.result_mq_handle = self.result_mq.export_handle() + logger.info(f"Worker {gpu_id} created result MessageQueue") + + assert od_config.master_port is not None + + # Create worker using WorkerWrapperBase for extension support + self.worker = self._create_worker(gpu_id, od_config, worker_extension_cls, custom_pipeline_args) + self._running = True + + def _create_worker( + self, + gpu_id: int, + od_config: OmniDiffusionConfig, + worker_extension_cls: str | None, + custom_pipeline_args: dict[str, Any] | None = None, + ) -> DiffusionWorker: + """Create a worker instance. Override in subclasses for different worker types.""" + wrapper = WorkerWrapperBase( + gpu_id=gpu_id, + od_config=od_config, + worker_extension_cls=worker_extension_cls, + custom_pipeline_args=custom_pipeline_args, + ) + return wrapper + + def return_result(self, output: object): + """Reply to client, only on rank 0.""" + if self.result_mq is not None: + try: + pack_diffusion_output_shm(output) + except Exception as e: + logger.warning("SHM pack failed, falling back to raw enqueue: %s", e) + self.result_mq.enqueue(output) + + def recv_message(self): + """Receive messages from broadcast queue.""" + return self.mq.dequeue(indefinite=True) + + def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]: + """Execute an RPC request and indicate whether to reply.""" + method = rpc_request["method"] + args = rpc_request.get("args", ()) + kwargs = rpc_request.get("kwargs", {}) + output_rank = rpc_request.get("output_rank") + exec_all_ranks = rpc_request.get("exec_all_ranks", False) + + should_execute = exec_all_ranks or output_rank is None or output_rank == self.gpu_id + should_reply = (output_rank is None or output_rank == self.gpu_id) and self.result_mq is not None + + if not should_execute: + return None, False + + try: + # Use execute_method from WorkerWrapperBase for consistent method resolution + result = self.worker.execute_method(method, *args, **kwargs) + return result, should_reply + except Exception as e: + logger.error(f"Error executing RPC: {e}", exc_info=True) + raise e + + def worker_busy_loop(self) -> None: + """Main busy loop for Multiprocessing Workers.""" + logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory") + + while self._running: + msg = None + try: + msg = self.recv_message() + except Exception as e: + logger.error( + f"Error receiving message in worker loop: {e}", + exc_info=True, + ) + continue + + if msg is None or len(msg) == 0: + logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id) + continue + + # Route message based on type + if isinstance(msg, dict) and msg.get("type") == "rpc": + try: + result, should_reply = self.execute_rpc(msg) + if should_reply: + self.return_result(result) + except Exception as e: + logger.error(f"Error processing RPC: {e}", exc_info=True) + if self.result_mq is not None: + self.return_result(DiffusionOutput(error=str(e))) + + elif isinstance(msg, dict) and msg.get("type") == "shutdown": + logger.info("Worker %s: Received shutdown message", self.gpu_id) + self._running = False + continue + + else: + # Handle generation request + try: + output = self.worker.execute_model(msg, self.od_config) + except Exception as e: + logger.error( + f"Error executing forward in event loop: {e}", + exc_info=True, + ) + output = DiffusionOutput(error=str(e)) + + try: + self.return_result(output) + except zmq.ZMQError as e: + logger.error(f"ZMQ error sending reply: {e}") + continue + + logger.info("event loop terminated.") + try: + self.worker.shutdown() + except Exception as exc: + logger.warning("Worker %s: Shutdown encountered an error: %s", self.gpu_id, exc) + self.context.term() + + @staticmethod + def worker_main( + rank: int, + od_config: OmniDiffusionConfig, + pipe_writer: mp.connection.Connection, + broadcast_handle, + worker_extension_cls: str | None = None, + custom_pipeline_args: dict[str, Any] | None = None, + ) -> None: + """Worker initialization and execution loops.""" + from vllm_omni.plugins import load_omni_general_plugins + + load_omni_general_plugins() + worker_proc = WorkerProc( + od_config, + gpu_id=rank, + broadcast_handle=broadcast_handle, + worker_extension_cls=worker_extension_cls, + custom_pipeline_args=custom_pipeline_args, + ) + logger.info(f"Worker {rank}: Scheduler loop started.") + pipe_writer.send( + { + "status": "ready", + "result_handle": worker_proc.result_mq_handle if rank == 0 else None, + } + ) + worker_proc.worker_busy_loop() + logger.info(f"Worker {rank}: Shutdown complete.") + + +class WorkerWrapperBase: + """ + Wrapper base class that creates DiffusionWorker with optional worker_extension_cls support. + This enables dynamic inheritance for DiffusionWorker to extend with custom functionality. + """ + + def __init__( + self, + gpu_id: int, + od_config: OmniDiffusionConfig, + base_worker_class: type = DiffusionWorker, + worker_extension_cls: str | None = None, + custom_pipeline_args: dict[str, Any] | None = None, + ): + """ + Initialize WorkerWrapperBase with support for worker extensions. + + Args: + gpu_id: GPU device ID + od_config: OmniDiffusionConfig configuration + worker_extension_cls: Optional qualified name of worker extension class + custom_pipeline_args: Optional arguments for custom pipeline initialization + """ + self.gpu_id = gpu_id + self.od_config = od_config + self.base_worker_class = base_worker_class + self.worker_extension_cls = worker_extension_cls + self.custom_pipeline_args = custom_pipeline_args + + # Prepare worker class with extension support + worker_class = self._prepare_worker_class() + + # Create the actual worker instance + # When custom_pipeline_args is provided, skip initial model loading + # since re_init_pipeline will handle it. This avoids allocating memory + # through CuMemAllocator twice, which causes assertion failures in + # sleep mode. + self.worker = worker_class( + local_rank=gpu_id, + rank=gpu_id, + od_config=od_config, + skip_load_model=(self.custom_pipeline_args is not None), + ) + + # Re-initialize pipeline with custom pipeline if provided + if self.custom_pipeline_args is not None: + self.worker.re_init_pipeline(self.custom_pipeline_args) + + def _prepare_worker_class(self) -> type: + """ + Prepare the worker class with optional extension. + Dynamically extends GPUWorker with worker_extension_cls if provided. + + Returns: + The worker class (potentially extended) + """ + worker_class = self.base_worker_class + + # If custom_pipeline_args is provided, use CustomPipelineWorkerExtension + if self.custom_pipeline_args is not None: + # Set worker_extension_cls to CustomPipelineWorkerExtension if not already set + if self.worker_extension_cls is None: + self.worker_extension_cls = CustomPipelineWorkerExtension + + if self.worker_extension_cls: + if isinstance(self.worker_extension_cls, str): + worker_extension_cls = resolve_obj_by_qualname(self.worker_extension_cls) + else: + worker_extension_cls = self.worker_extension_cls + extended_calls = [] + + if worker_extension_cls not in worker_class.__bases__: + # Check for conflicts between worker and extension + for attr in dir(worker_extension_cls): + if attr.startswith("__"): + continue + if hasattr(worker_class, attr): + logger.warning( + f"Worker class {worker_class} already has attribute " + f"{attr}, which may conflict with worker extension " + f"class {worker_extension_cls}." + ) + if callable(getattr(worker_extension_cls, attr)): + extended_calls.append(attr) + + # Dynamically inherit the worker extension class + class_name = f"{worker_class.__name__}With{worker_extension_cls.__name__}" + worker_class = type(class_name, (worker_extension_cls, worker_class), {}) + logger.info( + "Created extended worker class %s from %s for extended calls %s", + class_name, + worker_extension_cls, + extended_calls, + ) + + return worker_class + + def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: + """ + Generate output for the given requests. + + Args: + requests: List of diffusion requests + + Returns: + DiffusionOutput with generated results + """ + return self.worker.generate(requests) + + def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: + """ + Execute a forward pass. + + Args: + reqs: List of diffusion requests + od_config: OmniDiffusionConfig configuration + + Returns: + DiffusionOutput with generated results + """ + return self.worker.execute_model(reqs, od_config) + + def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute one diffusion step.""" + return self.worker.execute_stepwise(scheduler_output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load model weights. + + Args: + weights: Iterable of (name, tensor) tuples + + Returns: + Set of loaded weight names + """ + return self.worker.load_weights(weights) + + def sleep(self, level: int = 1) -> bool: + """ + Put the worker to sleep. The worker should not process any requests. + The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. + + Args: + level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. + Currently only support level 1. + + Returns: + True on success + """ + return self.worker.sleep(level) + + def wake_up(self, tags: list[str] | None = None) -> bool: + """ + Wake up the worker from sleep mode. See the sleep function + method for more details. + + Args: + tags: An optional list of tags to reallocate the worker memory + for specific memory allocations. Values must be in + `("weights")`. If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + worker is used again. + + Returns: + True on success + """ + return self.worker.wake_up(tags) + + def shutdown(self) -> None: + """Shutdown the worker and cleanup resources.""" + return self.worker.shutdown() + + def execute_method(self, method: str | bytes, *args, **kwargs) -> Any: + """ + Execute a method on the worker. + + Args: + method: Method name (str) or serialized callable (bytes) + + Returns: + Result of the method execution (type depends on the method) + + Raises: + Exception: If method execution fails + """ + try: + # Method resolution order: + # 1. If method is defined in this class, it will be called directly + # 2. Otherwise, since we define `__getattr__` and redirect attribute + # query to `self.worker`, the method will be called on the worker + assert isinstance(method, str), "Method must be str" + func = getattr(self.worker, method) + return func(*args, **kwargs) + + except Exception as e: + msg = f"Error executing method {method!r}. This might cause issues in distributed execution." + logger.exception(msg) + raise e + + def __getattr__(self, attr: str): + """Delegate attribute access to the wrapped worker.""" + return getattr(self.worker, attr) From 4bc32730ea5a31462ca30873bb205017dbae25a0 Mon Sep 17 00:00:00 2001 From: kje Date: Mon, 6 Apr 2026 09:41:27 +0900 Subject: [PATCH 4/4] fix: vLLM 0.18.0 compatibility for unit tests and config imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tests/unit/conftest.py: stub vllm_omni heavy init so unit tests can import stage_input_processors without a full vLLM installation - vllm_omni/config/model.py: guard _RUNNER_TASKS / TaskOption imports with try/except fallback for vLLM 0.18.0 where these were removed Co-Authored-By: 길재은 Co-Authored-By: Hyunjoon Jeong --- tests/unit/conftest.py | 18 +++++++++++++++++ vllm_omni/config/model.py | 41 ++++++++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 11 deletions(-) create mode 100644 tests/unit/conftest.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 00000000000..69db81f04ff --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,18 @@ +"""conftest.py for unit tests — stubs out heavy vllm_omni init.""" +import sys +import types + +# Provide a lightweight stub for vllm_omni so that submodule imports +# (e.g. vllm_omni.model_executor.stage_input_processors) don't trigger the +# full package __init__.py which requires a complete vLLM installation. +_stub = types.ModuleType("vllm_omni") +_stub.__path__ = [] +_stub.__spec__ = None +sys.modules.setdefault("vllm_omni", _stub) + +# Stub out vllm_omni.inputs.data.OmniTokensPrompt +_inputs = types.ModuleType("vllm_omni.inputs") +_inputs_data = types.ModuleType("vllm_omni.inputs.data") +_inputs_data.OmniTokensPrompt = dict # type: ignore[attr-defined] +sys.modules.setdefault("vllm_omni.inputs", _inputs) +sys.modules.setdefault("vllm_omni.inputs.data", _inputs_data) diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index e074689c9e2..1fdecfba3ff 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -6,18 +6,37 @@ import vllm.envs as envs from pydantic import ConfigDict from pydantic.dataclasses import dataclass -from vllm.attention.backends.registry import AttentionBackendEnum +try: + from vllm.attention.backends.registry import AttentionBackendEnum +except ImportError: + from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig, config -from vllm.config.model import ( - _RUNNER_CONVERTS, - _RUNNER_TASKS, - ConvertOption, - ConvertType, - RunnerOption, - TaskOption, - _get_and_verify_dtype, - get_served_model_name, -) +try: + from vllm.config.model import ( + _RUNNER_CONVERTS, + _RUNNER_TASKS, + ConvertOption, + ConvertType, + RunnerOption, + TaskOption, + _get_and_verify_dtype, + get_served_model_name, + ) +except ImportError: + # vLLM 0.18.0: _RUNNER_TASKS and TaskOption were removed/renamed + _RUNNER_TASKS: dict = { + "generate": {"generate", "auto"}, + "pooling": {"embed", "classify", "reward", "score"}, + } + from vllm.config.model import ( # type: ignore[no-redef] + _RUNNER_CONVERTS, + ConvertOption, + ConvertType, + RunnerOption, + _get_and_verify_dtype, + get_served_model_name, + ) + TaskOption = str # type: ignore[misc,assignment] from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.logger import init_logger