diff --git a/examples/online_serving/qwen3_omni/README.md b/examples/online_serving/qwen3_omni/README.md index 45482984b91..c3171e43667 100644 --- a/examples/online_serving/qwen3_omni/README.md +++ b/examples/online_serving/qwen3_omni/README.md @@ -36,6 +36,45 @@ cd examples/online_serving/qwen3_omni python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py --model Qwen/Qwen3-Omni-30B-A3B-Instruct --query-type use_image --port 8091 --host "localhost" ``` +#### Realtime WebSocket client (`openai_realtime_client.py`) + +[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, uploads a local audio file as **PCM16 mono @ 16 kHz** chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and prints **streaming transcription** (`transcription.delta` / `transcription.done`). + +**Dependencies:** + +```bash +pip install websockets librosa numpy +``` + +(ffmpeg may be required by `librosa` for some formats; see the FAQ below.) + +**From this directory** (`examples/online_serving/qwen3_omni`): + +```bash +python openai_realtime_client.py \ + --host localhost \ + --port 8091 \ + --model Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --audio_path /path/to/your.wav +``` + +If `--audio_path` is omitted, the script uses a bundled default clip (`mary_had_lamb` via vLLM assets). + +**Arguments:** + +| Flag | Default | Description | +|------|---------|-------------| +| `--host` | `localhost` | API server host | +| `--port` | `8000` | API server port (match your `vllm serve` port, e.g. `8091`) | +| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (also sent in `session.update`) | +| `--audio_path` | *(optional)* | Path to input audio; resampled to 16 kHz mono inside the client | + +Ensure the vLLM-Omni server is running with realtime support for this endpoint, for example: + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 +``` + The Python client supports the following command-line arguments: - `--query-type` (or `-q`): Query type (default: `use_video`). Options: `text`, `use_audio`, `use_image`, `use_video` diff --git a/examples/online_serving/qwen3_omni/openai_realtime_client.py b/examples/online_serving/qwen3_omni/openai_realtime_client.py new file mode 100644 index 00000000000..4fa043c481d --- /dev/null +++ b/examples/online_serving/qwen3_omni/openai_realtime_client.py @@ -0,0 +1,146 @@ +""" +This script demonstrates how to use the vLLM-Omni Realtime WebSocket API to perform +audio transcription by uploading an audio file. + +Before running this script, you must start the vLLM-Omni server with a realtime-capable +model, for example: + + vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni + +Requirements: +- vllm with audio support +- websockets +- librosa +- numpy + +The script: +1. Connects to the Realtime WebSocket endpoint +2. Converts an audio file to PCM16 @ 16kHz +3. Sends audio chunks to the server +4. Receives and prints transcription as it streams +""" + +import argparse +import asyncio +import base64 +import json + +import librosa +import numpy as np +import websockets +from vllm.assets.audio import AudioAsset + + +def audio_to_pcm16_base64(audio_path: str) -> str: + """ + Load an audio file and convert it to base64-encoded PCM16 @ 16kHz. + """ + # Load audio and resample to 16kHz mono + audio, _ = librosa.load(audio_path, sr=16000, mono=True) + # Convert to PCM16 + pcm16 = (audio * 32767).astype(np.int16) + # Encode as base64 + return base64.b64encode(pcm16.tobytes()).decode("utf-8") + + +async def realtime_transcribe(audio_path: str, host: str, port: int, model: str): + """ + Connect to the Realtime API and transcribe an audio file. + """ + uri = f"ws://{host}:{port}/v1/realtime" + + async with websockets.connect(uri) as ws: + # Wait for session.created + response = json.loads(await ws.recv()) + if response["type"] == "session.created": + print(f"Session created: {response['id']}") + else: + print(f"Unexpected response: {response}") + return + + # Validate model + await ws.send(json.dumps({"type": "session.update", "model": model})) + + # Signal ready to start + await ws.send(json.dumps({"type": "input_audio_buffer.commit"})) + + # Convert audio file to base64 PCM16 + print(f"Loading audio from: {audio_path}") + audio_base64 = audio_to_pcm16_base64(audio_path) + + # Send audio in chunks (4KB of raw audio = ~8KB base64) + chunk_size = 4096 + audio_bytes = base64.b64decode(audio_base64) + total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size + + print(f"Sending {total_chunks} audio chunks...") + for i in range(0, len(audio_bytes), chunk_size): + chunk = audio_bytes[i : i + chunk_size] + await ws.send( + json.dumps( + { + "type": "input_audio_buffer.append", + "audio": base64.b64encode(chunk).decode("utf-8"), + } + ) + ) + + # Signal all audio is sent + await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True})) + print("Audio sent. Waiting for transcription...\n") + + # Receive transcription + print("Transcription: ", end="", flush=True) + while True: + response = json.loads(await ws.recv()) + if response["type"] == "transcription.delta": + print(response["delta"], end="", flush=True) + elif response["type"] == "transcription.done": + print(f"\n\nFinal transcription: {response['text']}") + if response.get("usage"): + print(f"Usage: {response['usage']}") + break + elif response["type"] == "error": + print(f"\nError: {response['error']}") + break + + +def main(args): + if args.audio_path: + audio_path = args.audio_path + else: + # Use default audio asset + audio_path = str(AudioAsset("mary_had_lamb").get_local_path()) + print(f"No audio path provided, using default: {audio_path}") + + asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Realtime WebSocket Transcription Client") + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-Omni-30B-A3B-Instruct", + help="Model that is served and should be pinged.", + ) + parser.add_argument( + "--audio_path", + type=str, + default=None, + help="Path to the audio file to transcribe.", + ) + parser.add_argument( + "--host", + type=str, + default="localhost", + help="vLLM-Omni server host (default: localhost)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="vLLM-Omni server port (default: 8000)", + ) + args = parser.parse_args() + main(args) diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py index b2d2d9a9e53..ed6a7277b46 100644 --- a/tests/engine/test_async_omni_engine_input.py +++ b/tests/engine/test_async_omni_engine_input.py @@ -61,3 +61,31 @@ def test_build_add_request_message_preserves_additional_information(): assert request.additional_information.entries["text"].list_data == ["hello world"] assert request.additional_information.entries["speaker"].list_data == ["vivian"] output_processor.add_request.assert_called_once() + + +def test_build_add_request_message_with_resumable_streaming(): + engine = object.__new__(AsyncOmniEngine) + params = SamplingParams(max_tokens=8) + engine.default_sampling_params_list = [params] + engine.stage_metadata = [{"stage_type": "llm"}] + engine.supported_tasks = ("generate",) + + input_processor = Mock() + input_processor.process_inputs.return_value = _make_engine_core_request() + engine.input_processor = input_processor + + output_processor = Mock() + engine.output_processors = [output_processor] + + msg = engine._build_add_request_message( + request_id="req-stream", + prompt={"prompt_token_ids": [1, 2, 3]}, + sampling_params_list=[params], + final_stage_id=0, + resumable=True, + message_type="streaming_update", + ) + + assert msg["type"] == "streaming_update" + input_processor.process_inputs.assert_called_once() + assert input_processor.process_inputs.call_args.kwargs["resumable"] is True diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 9de3dc867ff..71bf6e2379f 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -635,9 +635,13 @@ def _build_add_request_message( self, request_id: str, prompt: EngineCoreRequest | PromptType, + prompt_text: str | None = None, sampling_params_list: Sequence[Any] | None = None, final_stage_id: int = 0, arrival_time: float | None = None, + *, + resumable: bool = False, + message_type: str = "add_request", ) -> dict[str, Any]: """Build an add_request message after stage-0 preprocessing.""" effective_sampling_params_list = ( @@ -669,6 +673,7 @@ def _build_add_request_message( params=params, supported_tasks=self.supported_tasks, arrival_time=arrival_time, + resumable=resumable, ) # TODO (Peiqi): add this for Qwen3-TTS only. Other models don't have # additional_information field in the prompt. @@ -683,9 +688,10 @@ def _build_add_request_message( request.external_req_id = request_id # Register with stage 0's output processor. + output_prompt_text = prompt_text self.output_processors[0].add_request( request=request, - prompt=prompt, + prompt=output_prompt_text, parent_req=None, request_index=0, queue=None, @@ -693,7 +699,7 @@ def _build_add_request_message( prompt = request return { - "type": "add_request", + "type": message_type, "request_id": request_id, "prompt": prompt, "original_prompt": original_prompt, @@ -949,9 +955,12 @@ def add_request( self, request_id: str, prompt: EngineCoreRequest | PromptType, + prompt_text: str | None = None, sampling_params_list: Sequence[Any] | None = None, final_stage_id: int = 0, arrival_time: float | None = None, + *, + resumable: bool = False, ) -> None: """Process stage-0 input locally, then send to the Orchestrator. @@ -963,9 +972,11 @@ def add_request( msg = self._build_add_request_message( request_id=request_id, prompt=prompt, + prompt_text=prompt_text, sampling_params_list=sampling_params_list, final_stage_id=final_stage_id, arrival_time=arrival_time, + resumable=resumable, ) if self.request_queue is None: raise RuntimeError("request_queue is not initialized") @@ -984,17 +995,70 @@ async def add_request_async( self, request_id: str, prompt: EngineCoreRequest | PromptType, + prompt_text: str | None = None, sampling_params_list: Sequence[Any] | None = None, final_stage_id: int = 0, arrival_time: float | None = None, + *, + resumable: bool = False, ) -> None: """Async add_request API.""" self.add_request( request_id=request_id, prompt=prompt, + prompt_text=prompt_text, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id, + arrival_time=arrival_time, + resumable=resumable, + ) + + def add_streaming_update( + self, + request_id: str, + prompt: EngineCoreRequest | PromptType, + prompt_text: str | None = None, + sampling_params_list: Sequence[Any] | None = None, + final_stage_id: int = 0, + arrival_time: float | None = None, + *, + resumable: bool = True, + ) -> None: + """Send an incremental streaming update for an existing request.""" + msg = self._build_add_request_message( + request_id=request_id, + prompt=prompt, + prompt_text=prompt_text, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id, + arrival_time=arrival_time, + resumable=resumable, + message_type="streaming_update", + ) + if self.request_queue is None: + raise RuntimeError("request_queue is not initialized") + self.request_queue.sync_q.put_nowait(msg) + + async def add_streaming_update_async( + self, + request_id: str, + prompt: EngineCoreRequest | PromptType, + prompt_text: str | None = None, + sampling_params_list: Sequence[Any] | None = None, + final_stage_id: int = 0, + arrival_time: float | None = None, + *, + resumable: bool = True, + ) -> None: + """Async wrapper for add_streaming_update().""" + self.add_streaming_update( + request_id=request_id, + prompt=prompt, + prompt_text=prompt_text, sampling_params_list=sampling_params_list, final_stage_id=final_stage_id, arrival_time=arrival_time, + resumable=resumable, ) def try_get_output(self, timeout: float = 0.001) -> dict[str, Any] | None: diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 8128c25c645..4a85a2c6c9d 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -200,6 +200,8 @@ async def _request_handler(self) -> None: if msg_type == "add_request": await self._handle_add_request(msg) + elif msg_type == "streaming_update": + await self._handle_streaming_update(msg) elif msg_type == "add_companion_request": await self._handle_add_companion(msg) elif msg_type == "abort": @@ -659,6 +661,34 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None: if self.async_chunk and stage_id == 0 and final_stage_id > 0: await self._prewarm_async_chunk_stages(request_id, request, req_state) + async def _handle_streaming_update(self, msg: dict[str, Any]) -> None: + """Handle a streaming_update message for an existing request.""" + stage_id = 0 + request_id = msg["request_id"] + request = msg["prompt"] + + req_state = self.request_states.get(request_id) + if req_state is None: + logger.warning( + "[Orchestrator] streaming_update for unknown req=%s, falling back to add_request", + request_id, + ) + fallback_msg = dict(msg) + fallback_msg["type"] = "add_request" + await self._handle_add_request(fallback_msg) + return + + if "sampling_params_list" in msg and msg["sampling_params_list"]: + req_state.sampling_params_list = msg["sampling_params_list"] + + req_state.stage_submit_ts[stage_id] = _time.time() + stage_client = self.stage_clients[stage_id] + if stage_client.stage_type == "diffusion": + params = req_state.sampling_params_list[stage_id] + await stage_client.add_request_async(request_id, request, params) + else: + await stage_client.add_request_async(request) + async def _prewarm_async_chunk_stages( self, request_id: str, diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 68c072c2b3f..6c8022461b2 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -12,12 +12,15 @@ from collections.abc import AsyncGenerator, Iterable, Sequence from typing import TYPE_CHECKING, Any -from vllm.engine.protocol import EngineClient +from vllm import TokensPrompt +from vllm.engine.protocol import EngineClient, StreamingInput from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams +from vllm.renderers.inputs.preprocess import extract_prompt_components +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask from vllm.v1.engine.exceptions import EngineDeadError @@ -147,7 +150,8 @@ def model_config(self): async def generate( self, - prompt: OmniPromptType | list[OmniPromptType], + prompt: OmniPromptType | AsyncGenerator[StreamingInput, None] | list[OmniPromptType], + sampling_params: Any = None, request_id: str = "", *, prompt_text: str | None = None, @@ -191,6 +195,7 @@ async def generate( logger.debug(f"[AsyncOmni] generate() called for request {request_id}") + input_stream_task: asyncio.Task | None = None try: # Start final output dispatcher on the first call to generate() self._final_output_handler() @@ -214,13 +219,22 @@ async def generate( req_state.metrics = metrics self.request_states[request_id] = req_state - # Add request to stage 0 (Orchestrator handles all stage transitions) - await self.engine.add_request_async( - request_id=request_id, - prompt=prompt, - sampling_params_list=sampling_params_list, - final_stage_id=final_stage_id_for_e2e, - ) + # Add request(s) to stage 0. For streaming inputs, submit + # chunks incrementally through streaming_update. + if isinstance(prompt, AsyncGenerator): + input_stream_task = await self._add_streaming_input_request( + request_id=request_id, + input_stream=prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id_for_e2e, + ) + else: + await self.engine.add_request_async( + request_id=request_id, + prompt=prompt, + sampling_params_list=sampling_params_list, + final_stage_id=final_stage_id_for_e2e, + ) submit_ts = time.time() req_state.metrics.stage_first_ts[0] = submit_ts req_start_ts[request_id] = submit_ts @@ -243,9 +257,118 @@ async def generate( self._log_summary_and_cleanup(request_id) except (asyncio.CancelledError, GeneratorExit): + if input_stream_task is not None and not input_stream_task.done(): + input_stream_task.cancel() await self.abort(request_id) logger.info(f"[AsyncOmni] Request {request_id} aborted.") raise + except Exception as e: + await self.abort(request_id) + logger.info(f"[AsyncOmni] Request {request_id} failed (input error): {e}") + raise + + async def _add_streaming_input_request( + self, + *, + request_id: str, + input_stream: AsyncGenerator[StreamingInput, None], + sampling_params_list: Sequence[OmniSamplingParams], + final_stage_id: int, + ) -> asyncio.Task: + """Submit a streaming input generator as incremental stage-0 updates.""" + if not sampling_params_list: + raise ValueError("sampling_params_list cannot be empty for streaming input") + # only check thinker's sampling params now + stage0_params = sampling_params_list[0] + self._validate_streaming_input_sampling_params(stage0_params) + + req_state = self.request_states[request_id] + + if not stage0_params.skip_clone: + stage0_params = stage0_params.clone() + stage0_params.skip_clone = True + stage0_params.output_kind = RequestOutputKind.DELTA + + has_submitted_first_chunk = False + + async def handle_inputs() -> None: + nonlocal has_submitted_first_chunk + cancelled = False + try: + async for chunk in input_stream: + chunk_params = getattr(chunk, "sampling_params", None) or stage0_params + self._validate_streaming_input_sampling_params(chunk_params) + chunk_sampling_params_list = list(sampling_params_list) + chunk_sampling_params_list[0] = chunk_params + chunk_prompt = chunk.prompt + prompt_text, _, _ = extract_prompt_components(self.model_config, chunk_prompt) + + if not has_submitted_first_chunk: + await self.engine.add_request_async( + request_id=request_id, + prompt=chunk_prompt, + prompt_text=prompt_text, + sampling_params_list=chunk_sampling_params_list, + final_stage_id=final_stage_id, + resumable=True, + ) + has_submitted_first_chunk = True + else: + await self.engine.add_streaming_update_async( + request_id=request_id, + prompt=chunk_prompt, + prompt_text=prompt_text, + sampling_params_list=chunk_sampling_params_list, + final_stage_id=final_stage_id, + resumable=True, + ) + except (asyncio.CancelledError, GeneratorExit): + cancelled = True + except Exception as error: + await req_state.queue.put({"request_id": request_id, "error": error}) + finally: + if not cancelled: + # Send empty final request to indicate that inputs have + # finished. Don't send if canceled (session was aborted). + final_sampling_params_list = list(sampling_params_list) + final_sampling_params_list[0] = stage0_params + final_prompt = TokensPrompt(prompt_token_ids=[0]) + + if has_submitted_first_chunk: + await self.engine.add_streaming_update_async( + request_id=request_id, + prompt=final_prompt, + prompt_text=None, + sampling_params_list=final_sampling_params_list, + final_stage_id=final_stage_id, + resumable=False, + ) + else: + await self.engine.add_request_async( + request_id=request_id, + prompt=final_prompt, + prompt_text=None, + sampling_params_list=final_sampling_params_list, + final_stage_id=final_stage_id, + resumable=False, + ) + + input_stream_task = asyncio.create_task(handle_inputs()) + req_state.input_stream_task = input_stream_task + return input_stream_task + + @staticmethod + def _validate_streaming_input_sampling_params(params: OmniSamplingParams) -> None: + if ( + not isinstance(params, SamplingParams) + or params.n > 1 + or params.output_kind == RequestOutputKind.FINAL_ONLY + or params.stop + ): + raise ValueError( + "Input streaming is currently supported only for SamplingParams " + "with n == 1, output_kind != FINAL_ONLY, and without stop strings." + ) async def encode( self, diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index d832b2726cf..0ffe33abde2 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -52,6 +52,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.orca_metrics import metrics_header +from vllm.entrypoints.openai.realtime.connection import RealtimeConnection +from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses from vllm.entrypoints.openai.server_utils import get_uvicorn_log_config from vllm.entrypoints.openai.speech_to_text.serving import ( @@ -803,6 +805,11 @@ async def omni_init_app_state( state.openai_streaming_speech = OmniStreamingSpeechHandler( speech_service=state.openai_serving_speech, ) + state.openai_serving_realtime = OpenAIServingRealtime( + engine_client=engine_client, + models=state.openai_serving_models, + request_logger=request_logger, + ) state.openai_serving_video = OmniOpenAIServingVideo( engine_client, @@ -1161,6 +1168,19 @@ async def streaming_speech(websocket: WebSocket): await handler.handle_session(websocket) +@router.websocket("/v1/realtime") +async def realtime_websocket(websocket: WebSocket): + """WebSocket endpoint for OpenAI-style realtime interactions.""" + serving = getattr(websocket.app.state, "openai_serving_realtime", None) + if serving is None: + await websocket.accept() + await websocket.send_json({"type": "error", "error": "Realtime API is not available", "code": "unsupported"}) + await websocket.close() + return + connection = RealtimeConnection(websocket, serving) + await connection.handle_connection() + + # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index ebe516e240b..04212ceeba8 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -3,10 +3,12 @@ # Copyright 2025 The Qwen team. """Inference-only Qwen3-Omni-Moe unified model (thinker + talker + code2wav).""" -from collections.abc import Iterable +import asyncio +from collections.abc import AsyncGenerator, Iterable from functools import cached_property from typing import Any +import numpy as np import torch import torch.nn as nn from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( @@ -15,10 +17,12 @@ Qwen3OmniMoeTalkerConfig, Qwen3OmniMoeThinkerConfig, ) -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.inputs.data import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal, SupportsPP, SupportsRealtime +from vllm.model_executor.models.qwen3_asr_realtime import Qwen3ASRRealtimeBuffer from vllm.model_executor.models.qwen3_omni_moe_thinker import ( Qwen3OmniMoeConditionalGenerationMixin, ) @@ -26,6 +30,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.transformers_utils.processor import cached_processor_from_config from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler @@ -34,6 +40,7 @@ from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker import ( Qwen3OmniMoeThinkerDummyInputsBuilder, + Qwen3OmniMoeThinkerForConditionalGeneration, Qwen3OmniMoeThinkerMultiModalProcessor, Qwen3OmniMoeThinkerProcessingInfo, ) @@ -70,7 +77,13 @@ dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, ) class Qwen3OmniMoeForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, Qwen3OmniMoeConditionalGenerationMixin, CustomProcessMixin, SupportsMRoPE + nn.Module, + SupportsMultiModal, + SupportsPP, + Qwen3OmniMoeConditionalGenerationMixin, + CustomProcessMixin, + SupportsMRoPE, + SupportsRealtime, ): """ Unified Qwen3 Omni MoE model combining thinker, talker, and code2wav. @@ -84,6 +97,8 @@ class Qwen3OmniMoeForConditionalGeneration( Set `model_stage` in vllm_config to one of: "thinker", "talker", "code2wav" """ + realtime_max_tokens = 64 + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.have_multimodal_outputs = True @@ -191,6 +206,46 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None ) + @classmethod + async def buffer_realtime_audio( + cls, + audio_stream: AsyncGenerator[np.ndarray, None], + input_stream: asyncio.Queue[list[int]], + model_config: ModelConfig, + ) -> AsyncGenerator[PromptType, None]: + processor = cached_processor_from_config(model_config) + feature_extractor = processor.feature_extractor + sampling_rate = feature_extractor.sampling_rate + tokenizer = cached_tokenizer_from_config(model_config) + + # Use a small segment size for low-latency streaming. + segment_duration_s = 5.0 + buffer = Qwen3ASRRealtimeBuffer( + sampling_rate=sampling_rate, + segment_duration_s=segment_duration_s, + ) + + audio_placeholder = Qwen3OmniMoeThinkerForConditionalGeneration.get_placeholder_str("audio", 0) + prompt_template = f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n<|im_start|>assistant\n" + + prompt_token_ids = tokenizer.encode(prompt_template) + + async for audio_chunk in audio_stream: + buffer.write_audio(audio_chunk) + + while (segment := buffer.read_audio()) is not None: + yield TokensPrompt( + prompt_token_ids=prompt_token_ids, + multi_modal_data={"audio": segment}, + ) + + remaining = buffer.flush() + if remaining is not None and len(remaining) > 0: + yield TokensPrompt( + prompt_token_ids=prompt_token_ids, + multi_modal_data={"audio": remaining}, + ) + # ==================== Device utilities ==================== @staticmethod