From 82a8792097a25725cdc417f8611806e2b4f936c4 Mon Sep 17 00:00:00 2001 From: indevn Date: Fri, 15 May 2026 14:30:47 +0800 Subject: [PATCH] Enable Qwen3-Omni realtime async-chunk bridge Signed-off-by: indevn --- examples/online_serving/qwen3_omni/README.md | 18 +- .../test_qwen3_omni_realtime_websocket.py | 34 +- .../test_realtime_connection_helpers.py | 235 +++++++++++- vllm_omni/entrypoints/async_omni.py | 13 +- vllm_omni/entrypoints/openai/__init__.py | 29 +- vllm_omni/entrypoints/openai/api_server.py | 16 - .../entrypoints/openai/realtime_connection.py | 356 ++++++++++++++---- vllm_omni/entrypoints/streaming_input.py | 18 + 8 files changed, 594 insertions(+), 125 deletions(-) create mode 100644 vllm_omni/entrypoints/streaming_input.py diff --git a/examples/online_serving/qwen3_omni/README.md b/examples/online_serving/qwen3_omni/README.md index a012cf21f9e..c451d4c1bfe 100644 --- a/examples/online_serving/qwen3_omni/README.md +++ b/examples/online_serving/qwen3_omni/README.md @@ -18,9 +18,7 @@ Asynchronous chunk streaming operates as **enabled by default** within this bund Additionally, NPU, ROCm, and XPU per-platform configuration deltas are deterministically merged from the `platforms`: section of the corresponding YAML. -**Note:** The OpenAI-style **`/v1/realtime`** WebSocket interface (facilitating streaming PCM audio input alongside audio and transcription output) -is currently **unsupported** while the `async_chunk` configuration attribute is enabled. -It is requisite to instantiate the default omni architecture or utilize a deployment configuration specifying `async_chunk: false` to facilitate real-time streaming sessions. +The OpenAI-style **`/v1/realtime`** WebSocket interface accepts streaming PCM audio input and returns audio plus transcription events. With `async_chunk: true`, realtime sessions use a commit-then-generate bridge: the server buffers uploaded PCM chunks until `input_audio_buffer.commit` with `final: true`, then submits one normal multimodal Qwen3-Omni request through the async-chunk Thinker -> Talker -> Code2Wav pipeline. Use `--no-async-chunk` only when you specifically want the legacy streaming-input path where generation can start from a non-final commit. To explicitly utilize a custom deployment YAML, mandate the configuration path accordingly: ```bash @@ -141,7 +139,9 @@ parser defaults). If you don't pass a flag, the YAML value wins. > bool. Pipelines that implement alternate processor functions for > chunked vs end-to-end modes (e.g. qwen3_tts code2wav) dispatch > automatically based on that bool — no extra flag or variant yaml is -> needed. +> needed. For `/v1/realtime`, `async_chunk: true` waits for a final commit +> before generation, while `--no-async-chunk` preserves the legacy +> streaming-input behavior. > ⚠️ **For multi-stage models that share GPUs (qwen3_omni_moe by default > shares cuda:1 between stages 1 and 2), avoid using global memory flags.** @@ -255,7 +255,7 @@ python examples/online_serving/openai_chat_completion_client_for_multimodal_gene [`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, streams a local WAV as **PCM16 mono @ 16 kHz** in fixed-size chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and receives **`response.audio.delta`** (incremental PCM for the reply) plus **`transcription.*`** events. By default it concatenates audio deltas and writes **`--output-wav`** (model output is typically **24 kHz**). Optional **`--delta-dump-dir`** saves each delta as `delta_000001.wav`, … for debugging. -Streaming input works well for translation-style use cases; if the Thinker runs while input is still incomplete, consider limiting **`max_tokens`** in your session / server defaults to avoid over-generation. +Streaming input works well for translation-style use cases. In `async_chunk: true` mode, the client can still upload chunks incrementally, but generation starts after the final commit. In `--no-async-chunk` mode, a non-final commit starts the legacy streaming-input path, so consider limiting **`max_tokens`** in your session / server defaults if the Thinker runs while input is still incomplete. **Dependencies:** @@ -289,12 +289,18 @@ python openai_realtime_client.py \ | `--num-requests` | `1` | Number of sequential sessions (see `--concurrency`) | | `--concurrency` | `1` | Max concurrent WebSocket sessions when `--num-requests` > 1 | -Ensure the server is running **without** `async_chunk` if you use `/v1/realtime`, for example: +The default Qwen3-Omni deployment enables `async_chunk`, so the command below uses the commit-then-generate bridge: ```bash vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 ``` +To use the legacy realtime streaming-input path instead, disable async chunking explicitly: + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --no-async-chunk +``` + The Python client supports the following command-line arguments: - `--query-type` (or `-q`): Query type (default: `use_video`). Options: `text`, `use_audio`, `use_image`, `use_video` diff --git a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py index 00a6565d926..8cde7f312a1 100644 --- a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py +++ b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py @@ -8,6 +8,7 @@ import asyncio import base64 +import inspect import io import json import os @@ -23,7 +24,7 @@ generate_synthetic_audio, ) from tests.helpers.runtime import OmniServerParams -from tests.helpers.stage_config import get_deploy_config_path +from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -37,6 +38,7 @@ # The new-schema CI overlay bakes in async_chunk: False and covers CUDA/ROCm/XPU # via its ``platforms:`` section, so one path serves all three. default_stage_config = get_deploy_config_path("ci/qwen3_omni_moe.yaml") +async_chunk_stage_config = modify_stage_config(default_stage_config, updates={"async_chunk": True}) realtime_server_params = [ pytest.param( @@ -46,7 +48,15 @@ use_stage_cli=True, server_args=["--no-async-chunk"], ), - id="default", + id="no_async_chunk", + ), + pytest.param( + OmniServerParams( + model=MODEL, + stage_config_path=async_chunk_stage_config, + use_stage_cli=True, + ), + id="async_chunk", ), ] @@ -74,6 +84,13 @@ def _wav_bytes_from_pcm16(pcm: bytes, sample_rate_hz: int) -> bytes: return buf.getvalue() +async def _connect_local_websocket(uri: str): + kwargs: dict = {"max_size": 64 * 1024 * 1024} + if "proxy" in inspect.signature(websockets.connect).parameters: + kwargs["proxy"] = None + return await websockets.connect(uri, **kwargs) + + async def _run_realtime_audio_roundtrip( host: str, port: int, @@ -92,7 +109,18 @@ async def _run_realtime_audio_roundtrip( bytes_per_ms = 16000 * 2 // 1000 chunk_bytes = max(bytes_per_ms * chunk_ms, 2) - async with websockets.connect(uri, max_size=64 * 1024 * 1024) as ws: + last_connect_error: Exception | None = None + for attempt in range(5): + try: + ws = await _connect_local_websocket(uri) + break + except (OSError, websockets.exceptions.InvalidMessage) as exc: + last_connect_error = exc + await asyncio.sleep(1 + attempt) + else: + raise AssertionError(f"Could not connect to realtime websocket: {last_connect_error!r}") + + async with ws: await ws.send(json.dumps({"type": "session.update", "model": model})) await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False})) diff --git a/tests/entrypoints/test_realtime_connection_helpers.py b/tests/entrypoints/test_realtime_connection_helpers.py index e795aa92d0f..fd4d0e0df0e 100644 --- a/tests/entrypoints/test_realtime_connection_helpers.py +++ b/tests/entrypoints/test_realtime_connection_helpers.py @@ -4,15 +4,21 @@ from __future__ import annotations +import asyncio import base64 +import importlib +import json +import sys +from types import SimpleNamespace +from typing import Any import numpy as np import pytest import torch from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection +from vllm_omni.entrypoints.streaming_input import validate_streaming_input_sampling_params pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -57,6 +63,223 @@ def test_pcm16_b64_roundtrip(self) -> None: assert pcm[2] == -32767 +class _FakeModel: + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str: + assert modality == "audio" + assert i == 0 + return "" + + +class _FakeWebSocket: + def __init__(self) -> None: + self.messages: list[str] = [] + + async def send_text(self, payload: str) -> None: + self.messages.append(payload) + + +class _FakeAsyncChunkEngine: + async_chunk = True + + def __init__(self) -> None: + self.default_sampling_params_list = [ + SamplingParams( + n=1, + output_kind=RequestOutputKind.CUMULATIVE, + ) + ] + self.generate_kwargs: dict[str, Any] | None = None + self.abort_calls: list[str] = [] + + def generate(self, **kwargs): + self.generate_kwargs = kwargs + + async def _results(): + yield SimpleNamespace( + prompt_token_ids=[11, 12], + outputs=[ + SimpleNamespace( + token_ids=[21, 22], + text="hello", + ) + ], + multimodal_output={ + "audio": np.array([0.0, 0.5], dtype=np.float32), + "sample_rate": 24000, + }, + ) + + return _results() + + async def abort(self, request_id: str) -> None: + self.abort_calls.append(request_id) + + +class _FakeServing: + model_cls = _FakeModel + model_config = None + + +class _FakePrivateEngineServing: + model_cls = _FakeModel + model_config = None + + def __init__(self, engine: _FakeAsyncChunkEngine) -> None: + self._engine_client = engine + + +class TestRealtimeConnectionAsyncChunkBridge: + def _connection(self) -> tuple[RealtimeConnection, _FakeAsyncChunkEngine, _FakeWebSocket]: + conn = RealtimeConnection.__new__(RealtimeConnection) + engine = _FakeAsyncChunkEngine() + websocket = _FakeWebSocket() + conn.engine = engine + conn.serving = _FakeServing() + conn.websocket = websocket + conn.audio_queue = asyncio.Queue() + conn.connection_id = "test-conn" + conn.generation_task = None + conn._is_connected = True + conn._is_model_validated = True + conn._max_audio_filesize_mb = 64 + conn._realtime_audio_ref = None + return conn, engine, websocket + + def test_init_binds_engine_without_runtime_async_omni_import(self) -> None: + engine = _FakeAsyncChunkEngine() + serving = SimpleNamespace(engine_client=engine) + + conn = RealtimeConnection(_FakeWebSocket(), serving) + + assert conn.engine is engine + assert conn._realtime_audio_ref is None + + def test_init_accepts_private_upstream_engine_client_attr(self) -> None: + engine = _FakeAsyncChunkEngine() + serving = _FakePrivateEngineServing(engine) + + conn = RealtimeConnection(_FakeWebSocket(), serving) + + assert conn.engine is engine + + def test_build_realtime_audio_prompt_uses_audio_tuple(self) -> None: + audio = np.array([[0.1], [-0.2]], dtype=np.float64) + prompt = RealtimeConnection._build_realtime_audio_prompt(audio, 16000, _FakeModel) + + assert prompt["prompt"] == "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n" + prompt_audio, sample_rate = prompt["multi_modal_data"]["audio"] + assert sample_rate == 16000 + assert prompt_audio.dtype == np.float32 + assert prompt_audio.shape == (2,) + np.testing.assert_allclose(prompt_audio, [0.1, -0.2], rtol=1e-6) + + @pytest.mark.asyncio + async def test_collect_committed_audio_concatenates_until_sentinel(self) -> None: + conn, _, _ = self._connection() + conn.audio_queue.put_nowait(np.array([0.1, 0.2], dtype=np.float32)) + conn.audio_queue.put_nowait(np.array([[0.3]], dtype=np.float32)) + conn.audio_queue.put_nowait(None) + + audio = await conn._collect_committed_audio() + + assert audio.dtype == np.float32 + np.testing.assert_allclose(audio, [0.1, 0.2, 0.3], rtol=1e-6) + + @pytest.mark.asyncio + async def test_collect_committed_audio_rejects_empty_commit(self) -> None: + conn, _, _ = self._connection() + conn.audio_queue.put_nowait(None) + + with pytest.raises(ValueError, match="No audio data"): + await conn._collect_committed_audio() + + @pytest.mark.asyncio + async def test_non_final_commit_does_not_start_async_chunk_generation(self) -> None: + conn, _, _ = self._connection() + called = False + + async def _fake_start() -> None: + nonlocal called + called = True + + conn._start_async_chunk_bridge_generation = _fake_start + + await conn.handle_event({"type": "input_audio_buffer.commit", "final": False}) + + assert called is False + assert conn.audio_queue.empty() + + @pytest.mark.asyncio + async def test_final_commit_starts_async_chunk_generation(self) -> None: + conn, _, _ = self._connection() + called = False + + async def _fake_start() -> None: + nonlocal called + called = True + + conn._start_async_chunk_bridge_generation = _fake_start + + await conn.handle_event({"type": "input_audio_buffer.commit", "final": True}) + + assert called is True + + @pytest.mark.asyncio + async def test_async_chunk_generation_passes_single_prompt_to_engine(self) -> None: + conn, engine, websocket = self._connection() + conn.audio_queue.put_nowait(np.array([0.25, -0.25], dtype=np.float32)) + conn.audio_queue.put_nowait(None) + + await conn._run_async_chunk_bridge_generation(asyncio.Queue()) + + assert engine.generate_kwargs is not None + prompt = engine.generate_kwargs["prompt"] + assert not hasattr(prompt, "__aiter__") + assert prompt["prompt"] == "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n" + prompt_audio, sample_rate = prompt["multi_modal_data"]["audio"] + assert sample_rate == 16000 + np.testing.assert_allclose(prompt_audio, [0.25, -0.25], rtol=1e-6) + + sampling_params_list = engine.generate_kwargs["sampling_params_list"] + assert sampling_params_list[0].output_kind == RequestOutputKind.DELTA + + events = [json.loads(message) for message in websocket.messages] + event_types = [event["type"] for event in events] + assert event_types == [ + "transcription.delta", + "response.audio.delta", + "transcription.done", + "response.audio.done", + ] + assert events[0]["delta"] == "hello" + assert events[1]["sample_rate_hz"] == 24000 + assert events[-1]["has_audio"] is True + + def test_realtime_sampling_params_do_not_mutate_engine_defaults(self) -> None: + conn, engine, _ = self._connection() + default_params = engine.default_sampling_params_list[0] + default_params.skip_clone = True + default_params.output_kind = RequestOutputKind.CUMULATIVE + + sampling_params_list = conn._realtime_sampling_params_list() + + assert sampling_params_list[0] is not default_params + assert sampling_params_list[0].output_kind == RequestOutputKind.DELTA + assert default_params.output_kind == RequestOutputKind.CUMULATIVE + + +class TestOpenAIEntrypointExports: + def test_serving_chat_export_does_not_eager_import_api_server(self) -> None: + sys.modules.pop("vllm_omni.entrypoints.openai.api_server", None) + + openai_entrypoints = importlib.import_module("vllm_omni.entrypoints.openai") + + assert "vllm_omni.entrypoints.openai.api_server" not in sys.modules + assert openai_entrypoints.OmniOpenAIServingChat.__name__ == "OmniOpenAIServingChat" + assert "vllm_omni.entrypoints.openai.api_server" not in sys.modules + + class TestAsyncOmniStreamingParamsValidation: def test_accepts_streaming_friendly_params(self) -> None: p = SamplingParams( @@ -64,23 +287,23 @@ def test_accepts_streaming_friendly_params(self) -> None: stop=[], output_kind=RequestOutputKind.DELTA, ) - AsyncOmni._validate_streaming_input_sampling_params(p) + validate_streaming_input_sampling_params(p) def test_rejects_non_sampling_params(self) -> None: with pytest.raises(ValueError, match="Input streaming"): - AsyncOmni._validate_streaming_input_sampling_params(object()) # type: ignore[arg-type] + validate_streaming_input_sampling_params(object()) # type: ignore[arg-type] def test_rejects_n_greater_than_one(self) -> None: p = SamplingParams(n=2, stop=[], output_kind=RequestOutputKind.DELTA) with pytest.raises(ValueError, match="Input streaming"): - AsyncOmni._validate_streaming_input_sampling_params(p) + validate_streaming_input_sampling_params(p) def test_rejects_final_only(self) -> None: p = SamplingParams(n=1, stop=[], output_kind=RequestOutputKind.FINAL_ONLY) with pytest.raises(ValueError, match="Input streaming"): - AsyncOmni._validate_streaming_input_sampling_params(p) + validate_streaming_input_sampling_params(p) def test_rejects_stop_strings(self) -> None: p = SamplingParams(n=1, stop=["\n"], output_kind=RequestOutputKind.DELTA) with pytest.raises(ValueError, match="Input streaming"): - AsyncOmni._validate_streaming_input_sampling_params(p) + validate_streaming_input_sampling_params(p) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index c22e780962a..ba2bff8b0ed 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -21,7 +21,6 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.renderers.inputs.preprocess import extract_prompt_components -from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask from vllm.v1.engine.exceptions import EngineDeadError @@ -32,6 +31,7 @@ OmniBase, OmniEngineDeadError, ) +from vllm_omni.entrypoints.streaming_input import validate_streaming_input_sampling_params from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput @@ -456,16 +456,7 @@ async def handle_inputs() -> None: @staticmethod def _validate_streaming_input_sampling_params(params: OmniSamplingParams) -> None: - if ( - not isinstance(params, SamplingParams) - or params.n > 1 - or params.output_kind == RequestOutputKind.FINAL_ONLY - or params.stop - ): - raise ValueError( - "Input streaming is currently supported only for SamplingParams " - "with n == 1, output_kind != FINAL_ONLY, and without stop strings." - ) + validate_streaming_input_sampling_params(params) async def encode( self, diff --git a/vllm_omni/entrypoints/openai/__init__.py b/vllm_omni/entrypoints/openai/__init__.py index e27cb238c54..7f65ed7133c 100644 --- a/vllm_omni/entrypoints/openai/__init__.py +++ b/vllm_omni/entrypoints/openai/__init__.py @@ -1,26 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -OpenAI-compatible API entrypoints for vLLM-Omni. - -Provides: -- omni_run_server: Main server entry point (auto-detects model type) -- OmniOpenAIServingChat: Unified chat completion handler for both LLM and diffusion models -""" - -from vllm_omni.entrypoints.openai.api_server import ( - build_async_omni, - omni_init_app_state, - omni_run_server, -) -from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat - __all__ = [ - # Server functions "omni_run_server", "build_async_omni", "omni_init_app_state", - # Serving classes "OmniOpenAIServingChat", ] + + +def __getattr__(name: str): + if name in {"omni_run_server", "build_async_omni", "omni_init_app_state"}: + from vllm_omni.entrypoints.openai import api_server + + return getattr(api_server, name) + if name == "OmniOpenAIServingChat": + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + return OmniOpenAIServingChat + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index c1467f7190a..214cb46036a 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -1378,22 +1378,6 @@ async def streaming_video_chat(websocket: WebSocket): @router.websocket("/v1/realtime") async def realtime_websocket(websocket: WebSocket): """WebSocket endpoint for OpenAI-style realtime interactions.""" - engine_client = getattr(websocket.app.state, "engine_client", None) - if engine_client is not None and getattr(engine_client, "async_chunk", False): - await websocket.accept() - await websocket.send_json( - { - "type": "error", - "error": ( - "The /v1/realtime API is not supported when async_chunk is enabled on the server. " - "Use a stage configuration with async_chunk disabled and restart the server before using " - "this endpoint." - ), - "code": "unsupported", - } - ) - await websocket.close() - return serving = getattr(websocket.app.state, "openai_serving_realtime", None) if serving is None: await websocket.accept() diff --git a/vllm_omni/entrypoints/openai/realtime_connection.py b/vllm_omni/entrypoints/openai/realtime_connection.py index 9fc2a1ee3a0..e604d8d7256 100644 --- a/vllm_omni/entrypoints/openai/realtime_connection.py +++ b/vllm_omni/entrypoints/openai/realtime_connection.py @@ -4,18 +4,25 @@ import base64 import json from collections.abc import AsyncGenerator -from typing import cast +from typing import TYPE_CHECKING, Any, cast from uuid import uuid4 import numpy as np from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.realtime.connection import RealtimeConnection as VllmRealtimeConnection -from vllm.entrypoints.openai.realtime.protocol import TranscriptionDelta, TranscriptionDone +from vllm.entrypoints.openai.realtime.protocol import ( + InputAudioBufferCommit, + TranscriptionDelta, + TranscriptionDone, +) from vllm.logger import init_logger -from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.openai.stage_params import clone_sampling_params from vllm_omni.entrypoints.utils import coerce_param_message_types +if TYPE_CHECKING: + from vllm_omni.entrypoints.async_omni import AsyncOmni + logger = init_logger(__name__) @@ -28,12 +35,81 @@ class RealtimeConnection(VllmRealtimeConnection): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.engine = cast(AsyncOmni, self.serving.engine_client) + self.engine = cast("AsyncOmni", self._get_serving_engine_client()) self._realtime_audio_ref: np.ndarray | None = None + def _get_serving_engine_client(self): + engine_client = getattr(self.serving, "engine_client", None) + if engine_client is None: + engine_client = getattr(self.serving, "_engine_client", None) + if engine_client is None: + raise ValueError("Realtime serving object does not expose an engine client.") + return engine_client + async def start_generation(self): + if self._uses_async_chunk_bridge: + # In async_chunk mode the bridge must wait for final=True so stage 0 + # receives a normal multimodal request instead of streaming updates. + logger.debug( + "Ignoring non-final realtime commit for async_chunk bridge: %s", + self.connection_id, + ) + return await super().start_generation() + @property + def _uses_async_chunk_bridge(self) -> bool: + return bool(getattr(self.engine, "async_chunk", False)) + + async def handle_event(self, event: dict): + if not self._uses_async_chunk_bridge: + await super().handle_event(event) + return + + if event.get("type") == "input_audio_buffer.append" and self._generation_in_progress(): + await self.send_error("Generation already in progress", "generation_in_progress") + return + + if event.get("type") != "input_audio_buffer.commit": + await super().handle_event(event) + return + + if not self._is_model_validated: + err_msg = ( + "Model not validated. Make sure to validate the" + " model by sending a session.update event." + ) + await self.send_error(err_msg, "model_not_validated") + return + + commit_event = InputAudioBufferCommit(**event) + if not commit_event.final: + logger.debug( + "Received non-final realtime commit in async_chunk bridge mode: %s", + self.connection_id, + ) + return + + await self._start_async_chunk_bridge_generation() + + async def _start_async_chunk_bridge_generation(self) -> None: + if self._generation_in_progress(): + logger.warning( + "Generation already in progress, ignoring final commit: %s", + self.connection_id, + ) + await self.send_error("Generation already in progress", "generation_in_progress") + return + + self.audio_queue.put_nowait(None) + input_stream = asyncio.Queue[list[int]]() + self.generation_task = asyncio.create_task( + self._run_async_chunk_bridge_generation(input_stream), + ) + + def _generation_in_progress(self) -> bool: + return self.generation_task is not None and not self.generation_task.done() + @staticmethod def _tensor_to_numpy(value) -> np.ndarray | None: if value is None: @@ -109,95 +185,243 @@ def _pcm16_b64(audio_f32: np.ndarray) -> str: pcm16 = (clipped * 32767.0).astype(np.int16) return base64.b64encode(pcm16.tobytes()).decode("utf-8") - async def _run_generation( + async def _collect_committed_audio(self) -> np.ndarray: + chunks: list[np.ndarray] = [] + total_pcm16_nbytes = 0 + + while True: + audio_chunk = await self.audio_queue.get() + if audio_chunk is None: + break + + arr = self._tensor_to_numpy(audio_chunk) + if arr is None or arr.size == 0: + continue + + arr = np.ascontiguousarray(arr, dtype=np.float32) + chunks.append(arr) + total_pcm16_nbytes += arr.size * np.dtype(np.int16).itemsize + + max_mb = getattr(self, "_max_audio_filesize_mb", None) + if max_mb is not None and total_pcm16_nbytes / 1024**2 > max_mb: + raise ValueError("Maximum file size exceeded") + + if not chunks: + raise ValueError("No audio data received before final commit.") + + return np.concatenate(chunks).astype(np.float32, copy=False) + + def _get_realtime_input_sample_rate(self) -> int: + model_config = getattr(self.serving, "model_config", None) + if model_config is None: + return 16000 + + try: + from vllm.transformers_utils.processor import cached_processor_from_config + + processor = cached_processor_from_config(model_config) + feature_extractor = getattr(processor, "feature_extractor", None) + sample_rate = getattr(feature_extractor, "sampling_rate", None) + if sample_rate: + return int(sample_rate) + except Exception: + logger.debug( + "Failed to resolve realtime input sample rate from processor; " + "falling back to 16 kHz.", + exc_info=True, + ) + + return 16000 + + @staticmethod + def _audio_placeholder_from_model_cls(model_cls: Any) -> str: + get_placeholder_str = getattr(model_cls, "get_placeholder_str", None) + if get_placeholder_str is not None: + placeholder = get_placeholder_str("audio", 0) + if placeholder: + return str(placeholder) + return "<|audio_start|><|audio_pad|><|audio_end|>" + + @classmethod + def _build_realtime_audio_prompt( + cls, + audio: np.ndarray, + sample_rate: int, + model_cls: Any = None, + ) -> dict[str, Any]: + audio = np.ascontiguousarray(audio, dtype=np.float32) + if audio.ndim > 1: + audio = audio.reshape(-1) + if audio.size == 0: + raise ValueError("No audio data received before final commit.") + + audio_placeholder = cls._audio_placeholder_from_model_cls(model_cls) + prompt = f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n<|im_start|>assistant\n" + return { + "prompt": prompt, + "multi_modal_data": {"audio": (audio, int(sample_rate))}, + } + + def _realtime_sampling_params_list(self): + sampling_params_list = [ + clone_sampling_params(params) + for params in self.engine.default_sampling_params_list + ] + return coerce_param_message_types( + sampling_params_list, + is_streaming=True, + ) + + async def _abort_generation_request(self, request_id: str) -> None: + abort = getattr(self.engine, "abort", None) + if abort is None: + return + try: + await abort(request_id) + except Exception: + logger.exception("Failed to abort realtime request: %s", request_id) + + async def _consume_generation_outputs( self, - streaming_input_gen: AsyncGenerator, + result_gen, input_stream: asyncio.Queue[list[int]], - ): - request_id = f"rt-{self.connection_id}-{uuid4()}" + ) -> tuple[bool, bool]: sent_audio = False - audio_done_sent = False full_text = "" prompt_token_ids_len = 0 completion_tokens_len = 0 - self._realtime_audio_ref = None - # Coerce cumulative outputs to delta outputs; this ensures - # we don't emit redundant MM data & drain after emitting. - sampling_params_list = list(self.engine.default_sampling_params_list) - sampling_params_list = coerce_param_message_types( - sampling_params_list, - is_streaming=True, + async for output in result_gen: + if output.outputs and len(output.outputs) > 0: + first_output = output.outputs[0] + new_token_ids = list(first_output.token_ids) + new_tokens_len = len(new_token_ids) + + if not prompt_token_ids_len and output.prompt_token_ids: + prompt_token_ids_len = len(output.prompt_token_ids) + + if new_tokens_len: + input_stream.put_nowait(new_token_ids) + + delta_text = first_output.text or "" + full_text += delta_text + + if delta_text: + await self.send(TranscriptionDelta(delta=delta_text)) + + completion_tokens_len += new_tokens_len + + audio_chunks, sample_rate = self._extract_audio_chunks(output) + + for chunk in audio_chunks: + sent_audio = True + await self.send_json( + { + "type": "response.audio.delta", + "audio": self._pcm16_b64(chunk), + "format": "pcm16", + "sample_rate_hz": sample_rate, + } + ) + + if not self._is_connected: + return sent_audio, False + + usage = UsageInfo( + prompt_tokens=prompt_token_ids_len, + completion_tokens=completion_tokens_len, + total_tokens=prompt_token_ids_len + completion_tokens_len, ) + await self.send(TranscriptionDone(text=full_text, usage=usage)) + return sent_audio, True + + def _drain_audio_queue(self) -> None: + while not self.audio_queue.empty(): + self.audio_queue.get_nowait() + + async def _run_async_chunk_bridge_generation( + self, + input_stream: asyncio.Queue[list[int]], + ) -> None: + request_id = f"rt-{self.connection_id}-{uuid4()}" + sent_audio = False + audio_done_sent = False + completed = False + engine_request_started = False + self._realtime_audio_ref = None try: + audio = await self._collect_committed_audio() + prompt = self._build_realtime_audio_prompt( + audio, + self._get_realtime_input_sample_rate(), + getattr(self.serving, "model_cls", None), + ) + result_gen = self.engine.generate( - prompt=streaming_input_gen, + prompt=prompt, request_id=request_id, - sampling_params_list=sampling_params_list, + sampling_params_list=self._realtime_sampling_params_list(), ) + engine_request_started = True + sent_audio, completed = await self._consume_generation_outputs(result_gen, input_stream) + + if completed and sent_audio: + await self.send_json({"type": "response.audio.done", "has_audio": True}) + audio_done_sent = True + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception("Error in async_chunk bridge generation: %s", e) + await self.send_error(str(e), "processing_error") + finally: + if engine_request_started and not completed: + await self._abort_generation_request(request_id) + if self._is_connected and not audio_done_sent: + try: + await self.send_json({"type": "response.audio.done", "has_audio": sent_audio}) + except Exception: + logger.exception("Failed to send response.audio.done") + self._drain_audio_queue() + + async def _run_generation( + self, + streaming_input_gen: AsyncGenerator, + input_stream: asyncio.Queue[list[int]], + ): + request_id = f"rt-{self.connection_id}-{uuid4()}" + sent_audio = False + audio_done_sent = False + completed = False + engine_request_started = False + self._realtime_audio_ref = None - async for output in result_gen: - # Handle delta texts; this is very similar to the client from vLLM - if output.outputs and len(output.outputs) > 0: - first_output = output.outputs[0] - new_token_ids = list(first_output.token_ids) - new_tokens_len = len(new_token_ids) - - if not prompt_token_ids_len and output.prompt_token_ids: - prompt_token_ids_len = len(output.prompt_token_ids) - - if new_tokens_len: - input_stream.put_nowait(new_token_ids) - - delta_text = first_output.text or "" - full_text += delta_text - - # append output to input if there was any delta text - if delta_text: - await self.send(TranscriptionDelta(delta=delta_text)) - - completion_tokens_len += new_tokens_len - - # Handle audio chunking; this is Omni specific - audio_chunks, sample_rate = self._extract_audio_chunks(output) - - for chunk in audio_chunks: - sent_audio = True - await self.send_json( - { - "type": "response.audio.delta", - "audio": self._pcm16_b64(chunk), - "format": "pcm16", - "sample_rate_hz": sample_rate, - } - ) - - if not self._is_connected: - break - - usage = UsageInfo( - prompt_tokens=prompt_token_ids_len, - completion_tokens=completion_tokens_len, - total_tokens=prompt_token_ids_len + completion_tokens_len, + try: + result_gen = self.engine.generate( + prompt=streaming_input_gen, + request_id=request_id, + sampling_params_list=self._realtime_sampling_params_list(), ) - await self.send(TranscriptionDone(text=full_text, usage=usage)) + engine_request_started = True + sent_audio, completed = await self._consume_generation_outputs(result_gen, input_stream) - if sent_audio: + if completed and sent_audio: await self.send_json({"type": "response.audio.done", "has_audio": True}) audio_done_sent = True + except asyncio.CancelledError: + raise except Exception as e: logger.exception("Error in generation: %s", e) await self.send_error(str(e), "processing_error") finally: - # Always send terminal event so clients don't hang forever. + if engine_request_started and not completed: + await self._abort_generation_request(request_id) if self._is_connected and not audio_done_sent: try: await self.send_json({"type": "response.audio.done", "has_audio": sent_audio}) except Exception: logger.exception("Failed to send response.audio.done") - while not self.audio_queue.empty(): - self.audio_queue.get_nowait() + self._drain_audio_queue() async def send_json(self, payload: dict): await self.websocket.send_text(json.dumps(payload)) diff --git a/vllm_omni/entrypoints/streaming_input.py b/vllm_omni/entrypoints/streaming_input.py new file mode 100644 index 00000000000..63bcb1aafed --- /dev/null +++ b/vllm_omni/entrypoints/streaming_input.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from vllm.sampling_params import RequestOutputKind, SamplingParams + +from vllm_omni.inputs.data import OmniSamplingParams + + +def validate_streaming_input_sampling_params(params: OmniSamplingParams) -> None: + if ( + not isinstance(params, SamplingParams) + or params.n > 1 + or params.output_kind == RequestOutputKind.FINAL_ONLY + or params.stop + ): + raise ValueError( + "Input streaming is currently supported only for SamplingParams " + "with n == 1, output_kind != FINAL_ONLY, and without stop strings." + )