diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md index 2f069529876..c346b5b9745 100644 --- a/docs/serving/speech_api.md +++ b/docs/serving/speech_api.md @@ -164,6 +164,42 @@ curl -X POST http://localhost:8091/v1/audio/voices \ -F "name=custom_voice_1" ``` +## Streaming Text Input (WebSocket) + +The `/v1/audio/speech/stream` WebSocket endpoint accepts text incrementally and generates audio per sentence as boundaries are detected. + +> Note: text input is always streamed incrementally. Audio output remains sentence-scoped: +> use `stream_audio=false` for one binary frame per sentence, or `stream_audio=true` for one or more PCM chunks per sentence. + +### WebSocket Protocol + +Client -> Server: + +| Message | Description | +|---------|-------------| +| `{"type": "session.config", ...}` | Session configuration (sent once, first message) | +| `{"type": "input.text", "text": "..."}` | Text chunk | +| `{"type": "input.done"}` | End of input, flushes remaining buffer | + +Server -> Client: + +| Message | Description | +|---------|-------------| +| `{"type": "audio.start", "sentence_index": 0, "sentence_text": "...", "format": "pcm", "sample_rate": 24000}` | Audio generation starting for a sentence | +| Binary frame | Raw audio bytes (one or more PCM chunks when `stream_audio=true`) | +| `{"type": "audio.done", "sentence_index": 0, "total_bytes": 96000, "error": false}` | Audio complete for a sentence | +| `{"type": "session.done", "total_sentences": N}` | Session complete | +| `{"type": "error", "message": "..."}` | Non-fatal error | + +### Session Config Parameters + +All REST API parameters are supported, plus: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `stream_audio` | bool | false | Stream one or more PCM chunks per sentence over WebSocket | +| `split_granularity` | string | "sentence" | Text splitting granularity | + ```bash DELETE /v1/audio/voices/{name} diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md index 99442c29e2b..4709c9d4218 100644 --- a/examples/online_serving/qwen3_tts/README.md +++ b/examples/online_serving/qwen3_tts/README.md @@ -349,6 +349,42 @@ curl -X POST http://localhost:8091/v1/audio/speech \ - `speed` adjustment is not supported when streaming. - Requires the server stage config to have `async_chunk: true` (default in `qwen3_tts.yaml`). +## Streaming Text Input (WebSocket) + +The `/v1/audio/speech/stream` WebSocket endpoint accepts text incrementally, buffers and splits it at sentence boundaries, and generates audio per sentence. + +When `stream_audio=true`, each sentence is emitted as `audio.start`, one or more binary PCM frames, and `audio.done`. + +### Quick Start + +```bash +python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." + +python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt --stt-delay 0.1 +``` + +### WebSocket Protocol + +Client -> Server: + +```jsonc +{"type": "session.config", "voice": "Vivian", "task_type": "CustomVoice", "language": "Auto", "split_granularity": "sentence", "stream_audio": true, "response_format": "pcm"} +{"type": "input.text", "text": "Hello, how are you? "} +{"type": "input.done"} +``` + +Server -> Client: + +```jsonc +{"type": "audio.start", "sentence_index": 0, "sentence_text": "Hello, how are you?", "format": "pcm", "sample_rate": 24000} +// binary PCM frame(s) +{"type": "audio.done", "sentence_index": 0, "total_bytes": 96000, "error": false} +{"type": "session.done", "total_sentences": 1} +``` + ## Limitations - **Single request**: Batch processing is not yet optimized for online serving. diff --git a/examples/online_serving/qwen3_tts/streaming_speech_client.py b/examples/online_serving/qwen3_tts/streaming_speech_client.py new file mode 100644 index 00000000000..785c6a0e8d7 --- /dev/null +++ b/examples/online_serving/qwen3_tts/streaming_speech_client.py @@ -0,0 +1,249 @@ +"""WebSocket client for streaming text-input TTS. + +Connects to the /v1/audio/speech/stream endpoint, sends text incrementally +(simulating real-time STT output), and saves per-sentence audio files. + +Usage: + # Send full text at once + python streaming_speech_client.py --text "Hello world. How are you? I am fine." + + # Simulate STT: send text word-by-word with delay + python streaming_speech_client.py \ + --text "Hello world. How are you? I am fine." \ + --simulate-stt --stt-delay 0.1 + + # VoiceDesign task + python streaming_speech_client.py \ + --text "Today is a great day. The weather is nice." \ + --task-type VoiceDesign \ + --instructions "A cheerful young female voice" + + # Base task (voice cloning) + python streaming_speech_client.py \ + --text "Hello world. How are you?" \ + --task-type Base \ + --ref-audio /path/to/reference.wav \ + --ref-text "Transcript of reference audio" + +Requirements: + pip install websockets +""" + +import argparse +import asyncio +import json +import os + +try: + import websockets +except ImportError: + print("Please install websockets: pip install websockets") + raise SystemExit(1) + + +async def stream_tts( + url: str, + text: str, + config: dict, + output_dir: str, + simulate_stt: bool = False, + stt_delay: float = 0.1, +) -> None: + """Connect to the streaming TTS endpoint and process audio responses.""" + os.makedirs(output_dir, exist_ok=True) + + async with websockets.connect(url) as ws: + # 1. Send session config + config_msg = {"type": "session.config", **config} + await ws.send(json.dumps(config_msg)) + print(f"Sent session config: {config}") + + # 2. Send text (either all at once or word-by-word) + async def send_text(): + if simulate_stt: + words = text.split(" ") + for i, word in enumerate(words): + chunk = word + (" " if i < len(words) - 1 else "") + await ws.send( + json.dumps( + { + "type": "input.text", + "text": chunk, + } + ) + ) + print(f" Sent: {chunk!r}") + await asyncio.sleep(stt_delay) + else: + await ws.send( + json.dumps( + { + "type": "input.text", + "text": text, + } + ) + ) + print(f"Sent full text: {text!r}") + + # 3. Signal end of input + await ws.send(json.dumps({"type": "input.done"})) + print("Sent input.done") + + # Run sender and receiver concurrently + sender_task = asyncio.create_task(send_text()) + + response_format = config.get("response_format", "wav") + current_sentence_index = 0 + current_chunks: list[bytes] = [] + + try: + while True: + message = await ws.recv() + + if isinstance(message, bytes): + current_chunks.append(message) + print(f" Received audio chunk for sentence {current_sentence_index}: {len(message)} bytes") + else: + # JSON frame + msg = json.loads(message) + msg_type = msg.get("type") + + if msg_type == "audio.start": + current_sentence_index = msg["sentence_index"] + current_chunks = [] + print(f" [sentence {msg['sentence_index']}] Generating: {msg['sentence_text']!r}") + elif msg_type == "audio.done": + filename = os.path.join( + output_dir, + f"sentence_{msg['sentence_index']:03d}.{response_format}", + ) + with open(filename, "wb") as f: + f.write(b"".join(current_chunks)) + print( + f" [sentence {msg['sentence_index']}] Done" + f" bytes={msg.get('total_bytes', len(b''.join(current_chunks)))}" + f" error={msg.get('error', False)}" + f" -> {filename}" + ) + current_chunks = [] + elif msg_type == "session.done": + print(f"\nSession complete: {msg['total_sentences']} sentence(s) generated") + break + elif msg_type == "error": + print(f" ERROR: {msg['message']}") + else: + print(f" Unknown message: {msg}") + finally: + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass # Task cancellation is expected during shutdown + + print(f"\nAudio files saved to: {output_dir}/") + + +def main(): + parser = argparse.ArgumentParser(description="Streaming text-input TTS client") + parser.add_argument( + "--url", + default="ws://localhost:8000/v1/audio/speech/stream", + help="WebSocket endpoint URL", + ) + parser.add_argument( + "--text", + required=True, + help="Text to synthesize", + ) + parser.add_argument( + "--output-dir", + default="streaming_tts_output", + help="Directory to save audio files (default: streaming_tts_output)", + ) + + # Session config options + parser.add_argument("--model", default=None, help="Model name") + parser.add_argument("--voice", default="Vivian", help="Speaker voice") + parser.add_argument( + "--task-type", + default="CustomVoice", + choices=["CustomVoice", "VoiceDesign", "Base"], + help="TTS task type", + ) + parser.add_argument("--language", default="Auto", help="Language") + parser.add_argument("--instructions", default=None, help="Voice style instructions") + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "pcm", "flac", "mp3", "aac", "opus"], + help="Audio format", + ) + parser.add_argument( + "--stream-audio", + action="store_true", + help="Receive one or more PCM chunks per sentence (requires --response-format pcm)", + ) + parser.add_argument("--speed", type=float, default=1.0, help="Playback speed (0.25-4.0)") + parser.add_argument("--max-new-tokens", type=int, default=None, help="Max tokens") + + # Base task options + parser.add_argument("--ref-audio", default=None, help="Reference audio") + parser.add_argument("--ref-text", default=None, help="Reference text") + parser.add_argument( + "--x-vector-only-mode", + action="store_true", + default=False, + help="Speaker embedding only mode", + ) + + # STT simulation + parser.add_argument( + "--simulate-stt", + action="store_true", + help="Simulate STT by sending text word-by-word", + ) + parser.add_argument( + "--stt-delay", + type=float, + default=0.1, + help="Delay between words in STT simulation (seconds)", + ) + + args = parser.parse_args() + + # Build session config (only include non-None values) + config = {} + for key in [ + "model", + "voice", + "task_type", + "language", + "instructions", + "response_format", + "speed", + "max_new_tokens", + "ref_audio", + "ref_text", + ]: + val = getattr(args, key.replace("-", "_"), None) + if val is not None: + config[key] = val + if args.stream_audio: + config["stream_audio"] = True + if args.x_vector_only_mode: + config["x_vector_only_mode"] = True + + asyncio.run( + stream_tts( + url=args.url, + text=args.text, + config=config, + output_dir=args.output_dir, + simulate_stt=args.simulate_stt, + stt_delay=args.stt_delay, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/online_serving/test_qwen3_tts_websocket.py b/tests/e2e/online_serving/test_qwen3_tts_websocket.py new file mode 100644 index 00000000000..df051460119 --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_tts_websocket.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online tests for Qwen3-TTS WebSocket streaming speech. +""" + +import asyncio +import json +import os +from pathlib import Path + +import pytest +import websockets + +from tests.conftest import OmniServer +from tests.utils import hardware_test + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" + + +def get_stage_config() -> str: + return str( + Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml" + ) + + +@pytest.fixture(scope="module") +def omni_server(): + stage_config_path = get_stage_config() + + with OmniServer( + MODEL, + [ + "--stage-configs-path", + stage_config_path, + "--stage-init-timeout", + "120", + "--trust-remote-code", + "--enforce-eager", + "--disable-log-stats", + ], + env_dict={"VLLM_DISABLE_COMPILE_CACHE": "1"}, + ) as server: + yield server + + +async def _run_ws_session(host: str, port: int) -> dict: + uri = f"ws://{host}:{port}/v1/audio/speech/stream" + starts: list[dict] = [] + dones: list[dict] = [] + chunk_lengths: dict[int, list[int]] = {} + session_done: dict | None = None + + async with websockets.connect(uri, max_size=None) as ws: + await ws.send( + json.dumps( + { + "type": "session.config", + "model": MODEL, + "voice": "vivian", + "language": "English", + "response_format": "pcm", + "stream_audio": True, + } + ) + ) + await ws.send( + json.dumps( + { + "type": "input.text", + "text": ( + "Hello, this is a websocket streaming test for Qwen three TTS, " + "and this sentence is intentionally long enough to produce audio chunks. " + "This is the second sentence." + ), + } + ) + ) + await ws.send(json.dumps({"type": "input.done"})) + + while True: + message = await asyncio.wait_for(ws.recv(), timeout=180) + if isinstance(message, bytes): + if not starts: + raise AssertionError("Received audio bytes before audio.start") + sentence_index = starts[-1]["sentence_index"] + chunk_lengths.setdefault(sentence_index, []).append(len(message)) + continue + + payload = json.loads(message) + msg_type = payload.get("type") + if msg_type == "audio.start": + starts.append(payload) + chunk_lengths.setdefault(payload["sentence_index"], []) + elif msg_type == "audio.done": + dones.append(payload) + elif msg_type == "session.done": + session_done = payload + break + elif msg_type == "error": + raise AssertionError(f"WebSocket error: {payload['message']}") + else: + raise AssertionError(f"Unexpected WebSocket message: {payload}") + + return { + "starts": starts, + "dones": dones, + "chunk_lengths": chunk_lengths, + "session_done": session_done, + } + + +class TestQwen3TTSWebSocket: + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=4) + def test_streaming_pcm_output(self, omni_server) -> None: + result = asyncio.run(_run_ws_session(omni_server.host, omni_server.port)) + + starts = result["starts"] + dones = result["dones"] + chunk_lengths = result["chunk_lengths"] + session_done = result["session_done"] + + assert session_done is not None + assert session_done["total_sentences"] == 2 + assert len(starts) == 2 + assert len(dones) == 2 + + for idx, start in enumerate(starts): + assert start["type"] == "audio.start" + assert start["sentence_index"] == idx + assert start["format"] == "pcm" + assert start["sample_rate"] == 24000 + assert start["sentence_text"] + + for done in dones: + sentence_index = done["sentence_index"] + total_bytes = done["total_bytes"] + assert done["error"] is False + assert total_bytes > 0 + assert chunk_lengths[sentence_index], f"Expected binary PCM frames for sentence {sentence_index}" + assert sum(chunk_lengths[sentence_index]) == total_bytes diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index 7c520e88963..7e35fec0dc3 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -1,19 +1,23 @@ # tests/entrypoints/openai/test_serving_speech.py +import asyncio import logging import os from inspect import Signature, signature from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest import torch -from fastapi import FastAPI, HTTPException, UploadFile +from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.params import File, Form +from fastapi.responses import JSONResponse from fastapi.testclient import TestClient from pydantic import ValidationError from pytest_mock import MockerFixture +from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse +from vllm_omni.entrypoints.openai import api_server as api_server_module from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio, OpenAICreateSpeechRequest from vllm_omni.entrypoints.openai.serving_speech import ( @@ -1052,4 +1056,33 @@ async def test_omni_model_includes_generate(self): omni.stage_list = [stage] tasks = await omni.get_supported_tasks() assert "generate" in tasks - assert "speech" in tasks + + +def test_api_server_create_speech_wraps_error_response_status(): + handler = MagicMock() + handler.create_speech = AsyncMock( + return_value=ErrorResponse( + error=ErrorInfo(message="bad request", type="BadRequestError", param=None, code=400), + ) + ) + + app = FastAPI() + app.state.openai_serving_speech = handler + scope = { + "type": "http", + "app": app, + "method": "POST", + "path": "/v1/audio/speech", + "headers": [], + "query_string": b"", + "client": ("127.0.0.1", 12345), + "server": ("testserver", 80), + "scheme": "http", + } + raw_request = Request(scope) + request = OpenAICreateSpeechRequest(input="Hello") + + response = asyncio.run(api_server_module.create_speech(request, raw_request)) + + assert isinstance(response, JSONResponse) + assert response.status_code == 400 diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py new file mode 100644 index 00000000000..bd136ac7272 --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_speech_stream.py @@ -0,0 +1,387 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI, WebSocket +from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect + +from vllm_omni.entrypoints.openai import serving_speech_stream as streaming_speech_module +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _build_test_app(speech_service=None, *, idle_timeout=30.0, config_timeout=10.0): + if speech_service is None: + speech_service = MagicMock(spec=OmniOpenAIServingSpeech) + speech_service._generate_audio_bytes = AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav")) + speech_service._prepare_speech_generation = AsyncMock(return_value=("req-1", object(), {})) + + async def mock_generate_pcm_chunks(_generator, _request_id): + for chunk in (b"\x01\x02", b"\x03\x04\x05"): + yield chunk + + speech_service._generate_pcm_chunks = mock_generate_pcm_chunks + speech_service.engine_client = MagicMock() + speech_service.engine_client.abort = AsyncMock() + + handler = OmniStreamingSpeechHandler( + speech_service=speech_service, + idle_timeout=idle_timeout, + config_timeout=config_timeout, + ) + app = FastAPI() + + @app.websocket("/v1/audio/speech/stream") + async def ws_endpoint(websocket: WebSocket): + await handler.handle_session(websocket) + + return app, speech_service + + +class TestStreamingSpeechWebSocket: + def test_non_streaming_single_frame(self): + app, speech_service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": "Hello world. "}) + + start = ws.receive_json() + assert start["type"] == "audio.start" + assert start["sentence_index"] == 0 + assert start["sentence_text"] == "Hello world." + assert start["format"] == "wav" + + audio = ws.receive_bytes() + assert audio.startswith(b"RIFF") + + done = ws.receive_json() + assert done == {"type": "audio.done", "sentence_index": 0, "total_bytes": len(audio), "error": False} + + ws.send_json({"type": "input.done"}) + session_done = ws.receive_json() + assert session_done == {"type": "session.done", "total_sentences": 1} + + assert speech_service._generate_audio_bytes.await_count == 1 + + def test_streaming_multiple_binary_frames(self): + captured_requests = [] + + speech_service = MagicMock(spec=OmniOpenAIServingSpeech) + speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav")) + speech_service.engine_client = MagicMock() + speech_service.engine_client.abort = AsyncMock() + + async def mock_prepare_speech_generation(request): + captured_requests.append(request) + return "req-stream", object(), {} + + speech_service._prepare_speech_generation = mock_prepare_speech_generation + + async def mock_generate_pcm_chunks(_generator, _request_id): + for chunk in (b"\x01\x02", b"\x03\x04\x05", b"\x06"): + yield chunk + + speech_service._generate_pcm_chunks = mock_generate_pcm_chunks + app, _ = _build_test_app(speech_service) + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "stream_audio": True, + "response_format": "pcm", + "initial_codec_chunk_frames": 12, + } + ) + ws.send_json({"type": "input.text", "text": "Hello world. "}) + + start = ws.receive_json() + assert start["type"] == "audio.start" + assert start["format"] == "pcm" + assert start["sample_rate"] == 24000 + + assert ws.receive_bytes() == b"\x01\x02" + assert ws.receive_bytes() == b"\x03\x04\x05" + assert ws.receive_bytes() == b"\x06" + + done = ws.receive_json() + assert done == {"type": "audio.done", "sentence_index": 0, "total_bytes": 6, "error": False} + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 1} + + assert len(captured_requests) == 1 + assert captured_requests[0].stream is True + assert captured_requests[0].response_format == "pcm" + assert captured_requests[0].initial_codec_chunk_frames == 12 + assert speech_service._generate_audio_bytes.await_count == 0 + + def test_flush_on_input_done(self): + app, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": "Hello world without punctuation"}) + ws.send_json({"type": "input.done"}) + + assert ws.receive_json()["type"] == "audio.start" + assert ws.receive_bytes() + assert ws.receive_json() == { + "type": "audio.done", + "sentence_index": 0, + "total_bytes": 36, + "error": False, + } + assert ws.receive_json() == {"type": "session.done", "total_sentences": 1} + + def test_invalid_streaming_config(self): + app, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "stream_audio": True, + "response_format": "wav", + } + ) + error = ws.receive_json() + assert error["type"] == "error" + assert "response_format='pcm'" in error["message"] + + def test_empty_input_text_emits_no_audio(self): + app, speech_service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": ""}) + ws.send_json({"type": "input.done"}) + + assert ws.receive_json() == {"type": "session.done", "total_sentences": 0} + + assert speech_service._generate_audio_bytes.await_count == 0 + + def test_multiple_sentences_increment_indices(self): + app, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": "First sentence. Second sentence. "}) + + first_start = ws.receive_json() + assert first_start["sentence_index"] == 0 + ws.receive_bytes() + assert ws.receive_json() == { + "type": "audio.done", + "sentence_index": 0, + "total_bytes": 36, + "error": False, + } + + second_start = ws.receive_json() + assert second_start["sentence_index"] == 1 + ws.receive_bytes() + assert ws.receive_json() == { + "type": "audio.done", + "sentence_index": 1, + "total_bytes": 36, + "error": False, + } + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 2} + + def test_unknown_message_type_keeps_session_open(self): + app, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "unknown"}) + + error = ws.receive_json() + assert error == {"type": "error", "message": "Unknown message type: unknown"} + + ws.send_json({"type": "input.text", "text": "Hello world. "}) + assert ws.receive_json()["type"] == "audio.start" + ws.receive_bytes() + assert ws.receive_json() == { + "type": "audio.done", + "sentence_index": 0, + "total_bytes": 36, + "error": False, + } + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 1} + + def test_config_timeout_closes_session(self): + app, _ = _build_test_app(config_timeout=0.01) + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + error = ws.receive_json() + assert error == {"type": "error", "message": "Timeout waiting for session.config"} + + def test_generation_error_marks_audio_done(self): + speech_service = MagicMock(spec=OmniOpenAIServingSpeech) + speech_service._generate_audio_bytes = AsyncMock(side_effect=RuntimeError("boom")) + speech_service._prepare_speech_generation = AsyncMock(return_value=("req-err", object(), {})) + speech_service._generate_pcm_chunks = AsyncMock() + speech_service.engine_client = MagicMock() + speech_service.engine_client.abort = AsyncMock() + app, _ = _build_test_app(speech_service) + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": "Hello world. "}) + + assert ws.receive_json()["type"] == "audio.start" + assert ws.receive_json() == {"type": "error", "message": "Generation failed for sentence 0: boom"} + assert ws.receive_json() == {"type": "audio.done", "sentence_index": 0, "total_bytes": 0, "error": True} + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 1} + + def test_streaming_generation_error_marks_audio_done(self): + speech_service = MagicMock(spec=OmniOpenAIServingSpeech) + speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav")) + speech_service._prepare_speech_generation = AsyncMock(return_value=("req-stream-err", object(), {})) + speech_service.engine_client = MagicMock() + speech_service.engine_client.abort = AsyncMock() + + async def mock_generate_pcm_chunks(_generator, _request_id): + yield b"\x01\x02" + raise RuntimeError("stream boom") + + speech_service._generate_pcm_chunks = mock_generate_pcm_chunks + app, _ = _build_test_app(speech_service) + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json( + { + "type": "session.config", + "voice": "Vivian", + "stream_audio": True, + "response_format": "pcm", + } + ) + ws.send_json({"type": "input.text", "text": "Hello world. "}) + + assert ws.receive_json()["type"] == "audio.start" + assert ws.receive_bytes() == b"\x01\x02" + assert ws.receive_json() == { + "type": "error", + "message": "Generation failed for sentence 0: stream boom", + } + assert ws.receive_json() == { + "type": "audio.done", + "sentence_index": 0, + "total_bytes": 2, + "error": True, + } + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 1} + + def test_invalid_input_text_type_returns_validation_error(self): + app, speech_service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": 123}) + + assert ws.receive_json() == { + "type": "error", + "message": "input.text requires a string value", + } + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 0} + + assert speech_service._generate_audio_bytes.await_count == 0 + + def test_input_text_message_too_large(self, monkeypatch): + monkeypatch.setattr(streaming_speech_module, "_MAX_INPUT_TEXT_MESSAGE_SIZE", 32) + app, speech_service = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian"}) + ws.send_json({"type": "input.text", "text": "x" * 128}) + + assert ws.receive_json() == { + "type": "error", + "message": "input.text message too large", + } + + ws.send_json({"type": "input.done"}) + assert ws.receive_json() == {"type": "session.done", "total_sentences": 0} + + assert speech_service._generate_audio_bytes.await_count == 0 + + def test_session_config_message_too_large(self, monkeypatch): + monkeypatch.setattr(streaming_speech_module, "_MAX_CONFIG_MESSAGE_SIZE", 64) + app, _ = _build_test_app() + + with TestClient(app) as client: + with client.websocket_connect("/v1/audio/speech/stream") as ws: + ws.send_json({"type": "session.config", "voice": "Vivian", "ref_audio": "x" * 512}) + + assert ws.receive_json() == { + "type": "error", + "message": "session.config message too large", + } + + def test_disconnect_aborts_streaming_request(self): + speech_service = MagicMock(spec=OmniOpenAIServingSpeech) + speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav")) + speech_service._prepare_speech_generation = AsyncMock(return_value=("req-abort", object(), {})) + speech_service.engine_client = MagicMock() + speech_service.engine_client.abort = AsyncMock() + + async def mock_generate_pcm_chunks(_generator, _request_id): + yield b"\x01\x02" + + speech_service._generate_pcm_chunks = mock_generate_pcm_chunks + handler = OmniStreamingSpeechHandler(speech_service=speech_service) + + websocket = MagicMock() + websocket.send_json = AsyncMock(side_effect=[None, WebSocketDisconnect()]) + websocket.send_bytes = AsyncMock(side_effect=WebSocketDisconnect()) + + config = MagicMock() + config.model = None + config.voice = "Vivian" + config.task_type = None + config.language = None + config.instructions = None + config.response_format = "pcm" + config.speed = 1.0 + config.max_new_tokens = None + config.initial_codec_chunk_frames = None + config.ref_audio = None + config.ref_text = None + config.x_vector_only_mode = None + config.stream_audio = True + + with pytest.raises(WebSocketDisconnect): + asyncio.run(handler._generate_and_send(websocket, config, "Hello world.", 0)) + + speech_service.engine_client.abort.assert_awaited_once_with("req-abort") + assert websocket.send_json.await_count == 2 diff --git a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py new file mode 100644 index 00000000000..23d4d191fc2 --- /dev/null +++ b/tests/entrypoints/openai_api/test_text_splitter.py @@ -0,0 +1,253 @@ +"""Tests for SentenceSplitter used in streaming TTS input.""" + +import pytest + +from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter + +pytestmark = [pytest.mark.openai, pytest.mark.speech] + + +class TestSentenceSplitterEnglish: + """Tests for English sentence splitting.""" + + def test_single_sentence_no_boundary(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world") + assert result == [] + assert splitter.buffer == "Hello world" + + def test_single_sentence_with_boundary(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world. How are you?") + assert len(result) == 1 + assert result[0] == "Hello world." + + def test_multiple_sentences(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello. How are you? I am fine! ") + assert len(result) == 3 + assert result[0] == "Hello." + assert result[1] == "How are you?" + assert result[2] == "I am fine!" + + def test_exclamation_mark(self): + splitter = SentenceSplitter() + result = splitter.add_text("Wow, that is great! Tell me more.") + assert len(result) == 1 + assert result[0] == "Wow, that is great!" + + def test_question_mark(self): + splitter = SentenceSplitter() + result = splitter.add_text("Can you hear me? I hope so.") + assert len(result) == 1 + assert result[0] == "Can you hear me?" + + +class TestSentenceSplitterChinese: + """Tests for CJK sentence splitting.""" + + def test_chinese_period(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好世界。你好吗") + assert len(result) == 1 + assert result[0] == "你好世界。" + + def test_chinese_exclamation(self): + splitter = SentenceSplitter() + result = splitter.add_text("太好了!谢谢你") + assert len(result) == 1 + assert result[0] == "太好了!" + + def test_chinese_question(self): + splitter = SentenceSplitter() + result = splitter.add_text("你是谁?我是小明") + assert len(result) == 1 + assert result[0] == "你是谁?" + + def test_chinese_comma_no_split(self): + """Chinese commas are clause-level and should not trigger a split.""" + splitter = SentenceSplitter() + result = splitter.add_text("你好,世界") + assert result == [] + assert splitter.buffer == "你好,世界" + + def test_chinese_semicolon_no_split(self): + """Chinese semicolons are clause-level and should not trigger a split.""" + splitter = SentenceSplitter() + result = splitter.add_text("第一点;第二点") + assert result == [] + assert splitter.buffer == "第一点;第二点" + + def test_chinese_multiple(self): + splitter = SentenceSplitter() + result = splitter.add_text("你好!你好吗?我很好。") + assert len(result) == 3 + assert result[0] == "你好!" + assert result[1] == "你好吗?" + assert result[2] == "我很好。" + + +class TestSentenceSplitterMixed: + """Tests for mixed-language sentence splitting.""" + + def test_mixed_english_chinese(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello世界。How are you? ") + assert len(result) == 2 + assert result[0] == "Hello世界。" + assert result[1] == "How are you?" + + +class TestSentenceSplitterIncremental: + """Tests for incremental (multi-chunk) text input.""" + + def test_accumulation_across_chunks(self): + splitter = SentenceSplitter() + # First chunk: no boundary + result1 = splitter.add_text("Hello ") + assert result1 == [] + + # Second chunk: completes a sentence + result2 = splitter.add_text("world. How") + assert len(result2) == 1 + assert result2[0] == "Hello world." + assert splitter.buffer == "How" + + def test_word_by_word(self): + splitter = SentenceSplitter() + words = ["Hello, ", "how ", "are ", "you? ", "I ", "am ", "fine."] + all_sentences = [] + for word in words: + all_sentences.extend(splitter.add_text(word)) + + assert len(all_sentences) == 1 + assert all_sentences[0] == "Hello, how are you?" + # "I am fine." stays in buffer (no trailing whitespace after period) + + def test_three_chunks(self): + splitter = SentenceSplitter() + splitter.add_text("The quick brown ") + splitter.add_text("fox jumps. ") + result = splitter.add_text("Over the lazy dog. ") + # "The quick brown fox jumps." should have been returned on second chunk + # "Over the lazy dog." on third chunk + assert len(result) == 1 + assert result[0] == "Over the lazy dog." + + +class TestSentenceSplitterFlush: + """Tests for flush behavior.""" + + def test_flush_returns_remaining(self): + splitter = SentenceSplitter() + splitter.add_text("Hello world") + result = splitter.flush() + assert result == "Hello world" + assert splitter.buffer == "" + + def test_flush_empty_buffer(self): + splitter = SentenceSplitter() + result = splitter.flush() + assert result is None + + def test_flush_after_sentence(self): + splitter = SentenceSplitter() + splitter.add_text("Hello world. Remaining text") + result = splitter.flush() + assert result == "Remaining text" + + def test_flush_whitespace_only(self): + splitter = SentenceSplitter() + splitter.add_text("Hello. ") + # "Hello." extracted, buffer is " " + result = splitter.flush() + # Whitespace-only should return None + assert result is None + + def test_flush_clears_buffer(self): + splitter = SentenceSplitter() + splitter.add_text("some text") + splitter.flush() + assert splitter.buffer == "" + # Second flush should return None + assert splitter.flush() is None + + +class TestSentenceSplitterEdgeCases: + """Edge case tests.""" + + def test_empty_input(self): + splitter = SentenceSplitter() + result = splitter.add_text("") + assert result == [] + assert splitter.buffer == "" + + def test_none_like_empty(self): + """Empty string should not affect buffer.""" + splitter = SentenceSplitter() + splitter.add_text("Hello") + splitter.add_text("") + assert splitter.buffer == "Hello" + + def test_only_punctuation(self): + splitter = SentenceSplitter() + result = splitter.add_text(". ") + # "." is 1 char, below default min_sentence_length of 2 + # It will be carried forward + assert result == [] + + def test_min_sentence_length(self): + splitter = SentenceSplitter(min_sentence_length=10) + result = splitter.add_text("Hi. Hello world. ") + # "Hi." is 3 chars (< 10), so it gets carried to "Hello world." + assert len(result) == 1 + assert "Hi." in result[0] + assert "Hello world." in result[0] + + def test_short_segments_are_carried_until_long_enough(self): + splitter = SentenceSplitter(min_sentence_length=10) + result = splitter.add_text("Hi. Ok. Hello there. ") + assert result == ["Hi.Ok.Hello there."] + assert splitter.buffer == "" + + def test_min_sentence_length_zero(self): + splitter = SentenceSplitter(min_sentence_length=0) + result = splitter.add_text("A. B. ") + assert len(result) == 2 + + def test_no_boundary_then_flush(self): + splitter = SentenceSplitter() + result = splitter.add_text("Hello world how are you") + assert result == [] + flushed = splitter.flush() + assert flushed == "Hello world how are you" + + def test_consecutive_punctuation(self): + splitter = SentenceSplitter() + result = splitter.add_text("Really?! Yes, really. ") + assert len(result) >= 1 + + def test_reuse_after_flush(self): + """Splitter can be reused after flush.""" + splitter = SentenceSplitter() + splitter.add_text("First session.") + splitter.flush() + + result = splitter.add_text("Second session. More text") + assert len(result) == 1 + assert result[0] == "Second session." + assert splitter.buffer == "More text" + + +class TestSentenceSplitterBufferLimit: + """Tests for buffer overflow protection.""" + + def test_buffer_overflow_raises(self): + from vllm_omni.entrypoints.openai.text_splitter import _MAX_BUFFER_SIZE + + splitter = SentenceSplitter() + # Fill buffer just under the limit + splitter.add_text("x" * (_MAX_BUFFER_SIZE - 1)) + # One more char should exceed the limit + with pytest.raises(ValueError, match="exceeded maximum size"): + splitter.add_text("xx") diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 98c9752d257..1507ce8c3f7 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -19,10 +19,9 @@ import httpx import vllm.envs as envs -from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile +from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, WebSocket from fastapi.responses import JSONResponse, StreamingResponse from PIL import Image -from pydantic import BaseModel, Field from starlette.datastructures import State from starlette.routing import Route from vllm import SamplingParams @@ -99,6 +98,7 @@ ) from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt from vllm_omni.lora.request import LoRARequest @@ -106,30 +106,6 @@ logger = init_logger(__name__) router = APIRouter() -profiler_router = APIRouter() - - -def _should_enable_profiler_endpoints(args: Namespace) -> bool: - # Check upstream vLLM's profiler_config - profiler_config = getattr(args, "profiler_config", None) - if profiler_config is not None: - # profiler_config exists, check if profiler is set - profiler = getattr(profiler_config, "profiler", None) - if profiler is not None: - return True - - # TODO: remove this env after refactoring torch profiler to CLI args - env_value = os.environ.get("VLLM_TORCH_PROFILER_DIR") - return env_value is not None - - -class ProfileRequest(BaseModel): - """Request model for profiling endpoints.""" - - stages: list[int] | None = Field( - default=None, - description="List of stage IDs to profile. If None, profiles all stages.", - ) def _remove_route_from_router( @@ -259,11 +235,6 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, await omni_init_app_state(engine_client, app.state, args) - # Conditionally register profiler endpoints based on config or env var - if _should_enable_profiler_endpoints(args): - logger.warning("Profiler endpoints are enabled. This should ONLY be used for local development!") - app.include_router(profiler_router) - vllm_config = await engine_client.get_vllm_config() # Check if pure diffusion mode (vllm_config will be None) @@ -294,7 +265,6 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs, - ssl_ciphers=args.ssl_ciphers, h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, h11_max_header_count=args.h11_max_header_count, **uvicorn_kwargs, @@ -742,6 +712,10 @@ async def omni_init_app_state( engine_client, state.openai_serving_models, request_logger=request_logger ) + state.openai_streaming_speech = OmniStreamingSpeechHandler( + speech_service=state.openai_serving_speech, + ) + state.openai_serving_video = OmniOpenAIServingVideo( engine_client, model_name=served_model_names[0] if served_model_names else None, @@ -889,113 +863,29 @@ async def list_voices(raw_request: Request): if handler is None: return base(raw_request).create_error_response(message="The model does not support Speech API") - # Get all speakers (both model built-in and uploaded) speakers = sorted(handler.supported_speakers) if handler.supported_speakers else [] + return JSONResponse(content={"voices": speakers}) - # Get uploaded speakers details - uploaded_speakers = [] - if hasattr(handler, "uploaded_speakers"): - for voice_name, info in handler.uploaded_speakers.items(): - uploaded_speakers.append( - { - "name": info.get("name", voice_name), - "consent": info.get("consent", ""), - "created_at": info.get("created_at", 0), - "file_size": info.get("file_size", 0), - "mime_type": info.get("mime_type", ""), - } - ) - - return JSONResponse(content={"voices": speakers, "uploaded_voices": uploaded_speakers}) - - -@router.post( - "/v1/audio/voices", - responses={ - HTTPStatus.OK.value: {"model": dict}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def upload_voice( - raw_request: Request, - audio_sample: UploadFile = File(...), - consent: str = Form(...), - name: str = Form(...), -): - """Upload a new voice sample for voice cloning. - - Uploads an audio file that can be used as a reference for voice cloning - in Base task TTS requests. The voice can then be referenced by name - in subsequent TTS requests. - - Args: - audio_sample: Audio file (max 10MB) - consent: Consent recording ID - name: Name for the new voice - raw_request: Raw FastAPI request - - Returns: - JSON response with voice information - """ - handler = Omnispeech(raw_request) - if handler is None: - return base(raw_request).create_error_response(message="The model does not support Speech API") - - try: - # Upload the voice - result = await handler.upload_voice(audio_sample, consent, name) - - return JSONResponse(content={"success": True, "voice": result}) - - except ValueError as e: - return base(raw_request).create_error_response(message=str(e)) - except Exception as e: - logger.exception(f"Failed to upload voice: {e}") - return base(raw_request).create_error_response(message=f"Failed to upload voice: {str(e)}") - - -@router.delete( - "/v1/audio/voices/{name}", - responses={ - HTTPStatus.OK.value: {"model": dict}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def delete_voice(name: str, raw_request: Request): - """Delete an uploaded voice. - - Deletes the voice sample and associated metadata. Also removes any - cached voice clone prompts for this voice. - Args: - name: Name of the voice to delete - raw_request: Raw FastAPI request +@router.websocket("/v1/audio/speech/stream") +async def streaming_speech(websocket: WebSocket): + """WebSocket endpoint for streaming text input TTS. - Returns: - JSON response indicating success or failure + Accepts text incrementally, splits at sentence boundaries, and + returns audio per sentence. See serving_speech_stream.py for protocol. """ - handler = Omnispeech(raw_request) + handler = getattr(websocket.app.state, "openai_streaming_speech", None) if handler is None: - return base(raw_request).create_error_response(message="The model does not support Speech API") - - try: - # Delete the voice - success = await handler.delete_voice(name) - if not success: - return JSONResponse( - content={"success": False, "error": f"Voice '{name}' not found"}, - status_code=HTTPStatus.NOT_FOUND.value, - ) - - return JSONResponse(content={"success": True, "message": f"Voice '{name}' deleted successfully"}) - - except ValueError as e: - return base(raw_request).create_error_response(message=str(e)) - except Exception as e: - logger.exception(f"Failed to delete voice '{name}': {e}") - return base(raw_request).create_error_response(message=f"Failed to delete voice: {str(e)}") + await websocket.accept() + await websocket.send_json( + { + "type": "error", + "message": "Streaming speech is not available", + } + ) + await websocket.close() + return + await handler.handle_session(websocket) # Health and Model endpoints for diffusion mode @@ -1637,58 +1527,6 @@ def apply_stage_default_sampling_params( setattr(sampling_params, param_name, param_value) -@profiler_router.post("/start_profile") -async def start_profile(raw_request: Request, request: ProfileRequest | None = None): - """Start profiling for the engine. - - Args: - request: Optional request body with stages to profile. - - stages: List of stage IDs to profile. If None, profiles all stages. - - Example: - POST /start_profile - {"stages": [0, 1]} # Profile only stages 0 and 1 - """ - try: - stages = request.stages if request else None - logger.info("Starting profiler for stages: %s", stages if stages else "all") - engine_client = raw_request.app.state.engine_client - result = await engine_client.start_profile(stages=stages) - logger.info("Profiler started.") - return JSONResponse(content=result) - except Exception as e: - logger.exception("Failed to start profiler: %s", e) - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to start profiler: {str(e)}" - ) - - -@profiler_router.post("/stop_profile") -async def stop_profile(raw_request: Request, request: ProfileRequest | None = None): - """Stop profiling for the engine. - - Args: - request: Optional request body with stages to stop profiling. - - stages: List of stage IDs to stop profiling. If None, stops all stages. - - Example: - POST /stop_profile - {"stages": [0, 1]} # Stop profiling only stages 0 and 1 - """ - try: - stages = request.stages if request else None - logger.info("Stopping profiler for stages: %s", stages if stages else "all") - engine_client = raw_request.app.state.engine_client - result = await engine_client.stop_profile(stages=stages) - logger.info("Profiler stopped.") - return JSONResponse(content=result) - except Exception as e: - logger.exception("Failed to stop profiler: %s", e) - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}" - ) - - async def _run_video_generation( request: VideoGenerationRequest, raw_request: Request, diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index 871787173af..eb32d7f0546 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -100,3 +100,52 @@ class Config: class AudioResponse(BaseModel): audio_data: bytes | str media_type: str + + +class StreamingSpeechSessionConfig(BaseModel): + """Configuration sent as the first WebSocket message for streaming TTS.""" + + model: str | None = None + voice: str | None = None + task_type: Literal["CustomVoice", "VoiceDesign", "Base"] | None = None + language: str | None = None + instructions: str | None = None + response_format: Literal["wav", "pcm", "flac", "mp3", "aac", "opus"] = "wav" + speed: float | None = Field(default=1.0, ge=0.25, le=4.0) + max_new_tokens: int | None = Field(default=None, ge=1) + initial_codec_chunk_frames: int | None = Field( + default=None, + ge=0, + description="Initial chunk size for reduced TTFA. Overrides stage config for this session.", + ) + ref_audio: str | None = None + ref_text: str | None = None + x_vector_only_mode: bool | None = None + stream_audio: bool = Field( + default=False, + description=( + "If true, send raw PCM audio chunks progressively over WebSocket. " + "Requires response_format='pcm'. Speed adjustment is not supported when streaming." + ), + ) + split_granularity: Literal["sentence", "clause"] = Field( + default="sentence", + description=( + "Text splitting granularity: 'sentence' splits on .!?。!?, " + "'clause' also splits on CJK commas , and semicolons ;." + ), + ) + + @model_validator(mode="after") + def validate_streaming_constraints(self) -> "StreamingSpeechSessionConfig": + if self.stream_audio: + if self.response_format != "pcm": + raise ValueError( + "WebSocket streaming audio (stream_audio=true) requires response_format='pcm'. " + f"Got response_format='{self.response_format}'." + ) + if self.speed is None: + self.speed = 1.0 + elif self.speed != 1.0: + raise ValueError("Speed adjustment is not supported when stream_audio=true. Set speed=1.0 or omit it.") + return self diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 7bcf75ace9d..6b565647420 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -20,6 +20,7 @@ from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.metadata_manager import MetadataManager from vllm_omni.entrypoints.openai.protocol.audio import ( + AudioResponse, CreateAudio, OpenAICreateSpeechRequest, ) @@ -636,6 +637,7 @@ async def _generate_pcm_chunks(self, generator, request_id: str): raise except Exception as e: logger.exception("Streaming speech generation failed for %s: %s", request_id, e) + raise @staticmethod def _extract_audio_output(res) -> tuple[dict | None, str | None]: @@ -650,7 +652,7 @@ def _extract_audio_output(res) -> tuple[dict | None, str | None]: mm = getattr(ro, "multimodal_output", None) if ro else None if not mm: return None, None - key = "audio" if "audio" in mm else None + key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None) return mm, key def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: @@ -721,6 +723,95 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any return params + async def _prepare_speech_generation( + self, + request: OpenAICreateSpeechRequest, + ) -> tuple[str, Any, dict[str, Any]]: + if self.engine_client.errored: + raise self.engine_client.dead_error + + if self._is_tts: + validation_error = self._validate_tts_request(request) + if validation_error: + raise ValueError(validation_error) + + tts_params = self._build_tts_params(request) + if request.ref_audio is not None: + wav_list, sr = await self._resolve_ref_audio(request.ref_audio) + tts_params["ref_audio"] = [[wav_list, sr]] + + ph_len = self._estimate_prompt_len(tts_params) + prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params} + else: + tts_params = {} + prompt = {"prompt": request.input} + + request_id = f"speech-{random_uuid()}" + logger.info( + "TTS speech request %s: text=%r, task_type=%s", + request_id, + request.input[:50] + "..." if len(request.input) > 50 else request.input, + tts_params.get("task_type", ["unknown"])[0], + ) + + sampling_params_list = self.engine_client.default_sampling_params_list + generator = self.engine_client.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["audio"], + ) + return request_id, generator, tts_params + + async def _iter_pcm_audio_bytes(self, request: OpenAICreateSpeechRequest): + """Yield raw PCM bytes for a speech request as soon as chunks are decoded.""" + request_id, generator, _ = await self._prepare_speech_generation(request) + async for chunk in self._generate_pcm_chunks(generator, request_id): + yield chunk + + async def _generate_audio_bytes( + self, + request: OpenAICreateSpeechRequest, + ) -> tuple[bytes, str]: + request_id, generator, _ = await self._prepare_speech_generation(request) + + final_output: OmniRequestOutput | None = None + async for res in generator: + final_output = res + + if final_output is None: + raise ValueError("No output generated from the model.") + + audio_output, audio_key = self._extract_audio_output(final_output) + if audio_key is None: + raise ValueError("TTS model did not produce audio output.") + + audio_tensor = audio_output[audio_key] + sr_raw = audio_output.get("sr", 24000) + sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw + sample_rate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) + + if isinstance(audio_tensor, list): + import torch + + audio_tensor = torch.cat(audio_tensor, dim=-1) + if hasattr(audio_tensor, "float"): + audio_tensor = audio_tensor.float().detach().cpu().numpy() + + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.squeeze() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=sample_rate, + response_format=request.response_format or "wav", + speed=request.speed or 1.0, + stream_format=request.stream_format, + base64_encode=False, + ) + audio_response: AudioResponse = self.create_audio(audio_obj) + return audio_response.audio_data, audio_response.media_type + async def create_speech( self, request: OpenAICreateSpeechRequest, @@ -750,89 +841,16 @@ async def create_speech( logger.error("Error with model %s", error_check_ret) return error_check_ret - if self.engine_client.errored: - raise self.engine_client.dead_error - - request_id = f"speech-{random_uuid()}" - try: - if self._is_tts: - # Validate TTS parameters - validation_error = self._validate_tts_request(request) - if validation_error: - return self.create_error_response(validation_error) - - tts_params = self._build_tts_params(request) - if request.ref_audio is not None: - wav_list, sr = await self._resolve_ref_audio(request.ref_audio) - tts_params["ref_audio"] = [[wav_list, sr]] - - # Prompt length must match model-side embeddings; values are placeholders. - ph_len = self._estimate_prompt_len(tts_params) - prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params} - else: - tts_params = {} - prompt = {"prompt": request.input} - - logger.info( - "TTS speech request %s: text=%r, task_type=%s", - request_id, - request.input[:50] + "..." if len(request.input) > 50 else request.input, - tts_params.get("task_type", ["unknown"])[0], - ) - - sampling_params_list = self.engine_client.default_sampling_params_list - - generator = self.engine_client.generate( - prompt=prompt, - request_id=request_id, - sampling_params_list=sampling_params_list, - output_modalities=["audio"], - ) - if request.stream: + request_id, generator, _ = await self._prepare_speech_generation(request) return StreamingResponse( self._generate_pcm_chunks(generator, request_id), media_type="audio/pcm", ) - # Non-streaming: collect final output - final_output: OmniRequestOutput | None = None - async for res in generator: - final_output = res - - if final_output is None: - return self.create_error_response("No output generated from the model.") - - audio_output, audio_key = self._extract_audio_output(final_output) - if audio_key is None: - return self.create_error_response("TTS model did not produce audio output.") - - audio_tensor = audio_output[audio_key] - sr_raw = audio_output.get("sr", 24000) - sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw - sample_rate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) - - # async_chunk mode accumulates chunks as a list; concat first. - if isinstance(audio_tensor, list): - import torch - - audio_tensor = torch.cat(audio_tensor, dim=-1) - if hasattr(audio_tensor, "float"): - audio_tensor = audio_tensor.float().detach().cpu().numpy() - if audio_tensor.ndim > 1: - audio_tensor = audio_tensor.squeeze() - - audio_obj = CreateAudio( - audio_tensor=audio_tensor, - sample_rate=sample_rate, - response_format=request.response_format or "wav", - speed=request.speed or 1.0, - stream_format=request.stream_format, - base64_encode=False, - ) - audio_response = self.create_audio(audio_obj) - return Response(content=audio_response.audio_data, media_type=audio_response.media_type) + audio_bytes, media_type = await self._generate_audio_bytes(request) + return Response(content=audio_bytes, media_type=media_type) except asyncio.CancelledError: return self.create_error_response("Client disconnected") diff --git a/vllm_omni/entrypoints/openai/serving_speech_stream.py b/vllm_omni/entrypoints/openai/serving_speech_stream.py new file mode 100644 index 00000000000..d88c88b651d --- /dev/null +++ b/vllm_omni/entrypoints/openai/serving_speech_stream.py @@ -0,0 +1,289 @@ +"""WebSocket handler for streaming text input TTS. + +Accepts text incrementally via WebSocket, buffers and splits at sentence +boundaries, and generates audio per sentence using the existing TTS pipeline. + +Protocol: + Client -> Server: + {"type": "session.config", ...} # Session config (sent once first) + {"type": "input.text", "text": "..."} # Text chunks + {"type": "input.done"} # End of input + + Server -> Client: + {"type": "audio.start", "sentence_index": 0, "sentence_text": "...", "format": "wav"} + + {"type": "audio.done", "sentence_index": 0} + {"type": "session.done", "total_sentences": N} + {"type": "error", "message": "..."} +""" + +import asyncio +import json +from contextlib import aclosing + +from fastapi import WebSocket, WebSocketDisconnect +from pydantic import ValidationError +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.protocol.audio import ( + OpenAICreateSpeechRequest, + StreamingSpeechSessionConfig, +) +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.text_splitter import ( + SPLIT_CLAUSE, + SPLIT_SENTENCE, + SentenceSplitter, +) + +logger = init_logger(__name__) + +_DEFAULT_IDLE_TIMEOUT = 30.0 # seconds +_DEFAULT_CONFIG_TIMEOUT = 10.0 # seconds +_PCM_SAMPLE_RATE = 24000 +_MAX_CONFIG_MESSAGE_SIZE = 4 * 1024 * 1024 # allow large ref_audio payloads +_MAX_INPUT_TEXT_MESSAGE_SIZE = 128 * 1024 + + +class OmniStreamingSpeechHandler: + """Handles WebSocket sessions for streaming text-input TTS. + + Each WebSocket connection is an independent session. Text arrives + incrementally, is split at sentence boundaries, and audio is generated + per sentence using the existing OmniOpenAIServingSpeech pipeline. + + Args: + speech_service: The existing TTS serving instance (reused for + validation and audio generation). + idle_timeout: Max seconds to wait for a message before closing. + config_timeout: Max seconds to wait for the initial session.config. + """ + + def __init__( + self, + speech_service: OmniOpenAIServingSpeech, + idle_timeout: float = _DEFAULT_IDLE_TIMEOUT, + config_timeout: float = _DEFAULT_CONFIG_TIMEOUT, + ) -> None: + self._speech_service = speech_service + self._idle_timeout = idle_timeout + self._config_timeout = config_timeout + + async def handle_session(self, websocket: WebSocket) -> None: + """Main session loop for a single WebSocket connection.""" + await websocket.accept() + + try: + # 1. Wait for session.config + config = await self._receive_config(websocket) + if config is None: + return # Error already sent, connection closing + + # Validate model if specified + if config.model and hasattr(self._speech_service, "_check_model"): + error = await self._speech_service._check_model( + OpenAICreateSpeechRequest(input="ping", model=config.model) + ) + if error is not None: + await self._send_error(websocket, str(error)) + return + + boundary_re = SPLIT_CLAUSE if config.split_granularity == "clause" else SPLIT_SENTENCE + splitter = SentenceSplitter(boundary_re=boundary_re) + sentence_index = 0 + + # 2. Receive text chunks until input.done + while True: + try: + raw = await asyncio.wait_for( + websocket.receive_text(), + timeout=self._idle_timeout, + ) + except asyncio.TimeoutError: + await self._send_error(websocket, "Idle timeout: no message received") + return + + if len(raw) > _MAX_INPUT_TEXT_MESSAGE_SIZE: + await self._send_error(websocket, "input.text message too large") + continue + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await self._send_error(websocket, "Invalid JSON message") + continue + + if not isinstance(msg, dict): + await self._send_error(websocket, "WebSocket messages must be JSON objects") + continue + + msg_type = msg.get("type") + + if msg_type == "input.text": + text = msg.get("text", "") + if not isinstance(text, str): + await self._send_error(websocket, "input.text requires a string value") + continue + sentences = splitter.add_text(text) + for sentence in sentences: + await self._generate_and_send(websocket, config, sentence, sentence_index) + sentence_index += 1 + + elif msg_type == "input.done": + # Flush remaining buffer + remaining = splitter.flush() + if remaining: + await self._generate_and_send(websocket, config, remaining, sentence_index) + sentence_index += 1 + + # Send session.done + await websocket.send_json( + { + "type": "session.done", + "total_sentences": sentence_index, + } + ) + return + + else: + await self._send_error( + websocket, + f"Unknown message type: {msg_type}", + ) + + except WebSocketDisconnect: + logger.info("Streaming speech: client disconnected") + except Exception as e: + logger.exception("Streaming speech session error: %s", e) + try: + await self._send_error(websocket, f"Internal error: {e}") + except Exception: + logger.debug("Failed to send error to streaming speech client", exc_info=True) + + async def _receive_config(self, websocket: WebSocket) -> StreamingSpeechSessionConfig | None: + """Wait for and validate the session.config message.""" + try: + raw = await asyncio.wait_for( + websocket.receive_text(), + timeout=self._config_timeout, + ) + except asyncio.TimeoutError: + await self._send_error(websocket, "Timeout waiting for session.config") + return None + + if len(raw) > _MAX_CONFIG_MESSAGE_SIZE: + await self._send_error(websocket, "session.config message too large") + return None + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + await self._send_error(websocket, "Invalid JSON in session.config") + return None + + if not isinstance(msg, dict): + await self._send_error(websocket, "session.config must be a JSON object") + return None + + if msg.get("type") != "session.config": + await self._send_error( + websocket, + f"Expected session.config, got: {msg.get('type')}", + ) + return None + + try: + config = StreamingSpeechSessionConfig(**{k: v for k, v in msg.items() if k != "type"}) + except ValidationError as e: + await self._send_error(websocket, f"Invalid session config: {e}") + return None + + return config + + async def _generate_and_send( + self, + websocket: WebSocket, + config: StreamingSpeechSessionConfig, + sentence_text: str, + sentence_index: int, + ) -> None: + """Generate audio for a single sentence and send it over WebSocket.""" + response_format = config.response_format or "wav" + + request = OpenAICreateSpeechRequest( + input=sentence_text, + model=config.model, + voice=config.voice, + task_type=config.task_type, + language=config.language, + instructions=config.instructions, + response_format=response_format, + speed=config.speed, + max_new_tokens=config.max_new_tokens, + initial_codec_chunk_frames=config.initial_codec_chunk_frames, + ref_audio=config.ref_audio, + ref_text=config.ref_text, + x_vector_only_mode=config.x_vector_only_mode, + stream=config.stream_audio, + ) + + start_payload = { + "type": "audio.start", + "sentence_index": sentence_index, + "sentence_text": sentence_text, + "format": response_format, + } + if config.stream_audio and response_format == "pcm": + start_payload["sample_rate"] = _PCM_SAMPLE_RATE + await websocket.send_json(start_payload) + + total_bytes = 0 + generation_failed = False + request_id = None + try: + if config.stream_audio: + request_id, generator, _ = await self._speech_service._prepare_speech_generation(request) + async with aclosing(self._speech_service._generate_pcm_chunks(generator, request_id)) as stream: + async for chunk in stream: + total_bytes += len(chunk) + await websocket.send_bytes(chunk) + else: + audio_bytes, _ = await self._speech_service._generate_audio_bytes(request) + total_bytes = len(audio_bytes) + await websocket.send_bytes(audio_bytes) + except WebSocketDisconnect: + if request_id is not None: + try: + await self._speech_service.engine_client.abort(request_id) + except Exception: + logger.debug("Failed to abort streaming speech request %s", request_id, exc_info=True) + raise + except Exception as e: + generation_failed = True + logger.error("Generation failed for sentence %d: %s", sentence_index, e) + await self._send_error(websocket, f"Generation failed for sentence {sentence_index}: {e}") + finally: + try: + await websocket.send_json( + { + "type": "audio.done", + "sentence_index": sentence_index, + "total_bytes": total_bytes, + "error": generation_failed, + } + ) + except Exception: + logger.debug("Failed to send audio.done for sentence %d", sentence_index, exc_info=True) + + @staticmethod + async def _send_error(websocket: WebSocket, message: str) -> None: + """Send an error message to the client.""" + try: + await websocket.send_json( + { + "type": "error", + "message": message, + } + ) + except Exception: + pass # Connection may already be closed; safe to ignore diff --git a/vllm_omni/entrypoints/openai/text_splitter.py b/vllm_omni/entrypoints/openai/text_splitter.py new file mode 100644 index 00000000000..9c02de1396b --- /dev/null +++ b/vllm_omni/entrypoints/openai/text_splitter.py @@ -0,0 +1,120 @@ +"""Multi-language sentence boundary detector for streaming TTS input. + +Buffers incoming text and splits at sentence boundaries (English and CJK), +yielding complete sentences for audio generation. +""" + +import re +from re import Pattern + +# Maximum buffer size (in characters) to prevent unbounded memory growth. +_MAX_BUFFER_SIZE = 100_000 # ~100 KB of text + +# Sentence-level: .!? + CJK sentence-ending 。!? +# NOTE: English requires trailing whitespace to confirm a boundary — +# end-of-string is NOT treated as a boundary (that is what flush() is for). +SPLIT_SENTENCE = re.compile( + r"(?<=[.!?])\s+" + r"|(?<=[。!?])" +) + +# Clause-level: adds CJK commas , and semicolons ; +SPLIT_CLAUSE = re.compile( + r"(?<=[.!?])\s+" + r"|(?<=[。!?,;])" +) + +# Default alias +_SENTENCE_BOUNDARY_RE = SPLIT_SENTENCE + + +class SentenceSplitter: + """Incremental sentence splitter for streaming text input. + + Buffers text and yields complete sentences when boundaries are detected. + Designed for TTS pipelines where text arrives incrementally (e.g., from STT). + + Args: + min_sentence_length: Minimum character length for a sentence. + Sentences shorter than this are kept in the buffer to avoid + splitting on abbreviations like "Dr." or "U.S.". + boundary_re: Custom compiled regex for sentence boundaries. + Use ``SPLIT_SENTENCE`` (default) for sentence-level splitting, + ``SPLIT_CLAUSE`` for finer-grained clause-level splitting, + or pass your own ``re.Pattern``. + """ + + def __init__( + self, + min_sentence_length: int = 2, + boundary_re: Pattern[str] | None = None, + ) -> None: + self._buffer: str = "" + self._min_sentence_length = min_sentence_length + self._boundary_re = boundary_re or _SENTENCE_BOUNDARY_RE + + @property + def buffer(self) -> str: + """Current buffered text.""" + return self._buffer + + def add_text(self, text: str) -> list[str]: + """Add text to the buffer and return any complete sentences. + + Args: + text: Incoming text chunk. + + Returns: + List of complete sentences extracted from the buffer. + May be empty if no sentence boundary was found. + + Raises: + ValueError: If the buffer exceeds the maximum size. + """ + if not text: + return [] + + self._buffer += text + if len(self._buffer) > _MAX_BUFFER_SIZE: + raise ValueError( + f"Text buffer exceeded maximum size ({_MAX_BUFFER_SIZE} chars). " + "Consider adding sentence-ending punctuation to your input." + ) + return self._extract_sentences() + + def flush(self) -> str | None: + """Flush remaining buffered text as a final sentence. + + Returns: + The remaining buffered text (stripped), or None if buffer is empty. + """ + remaining = self._buffer.strip() + self._buffer = "" + return remaining if remaining else None + + def _extract_sentences(self) -> list[str]: + """Split buffer at sentence boundaries, keeping incomplete text buffered.""" + parts = self._boundary_re.split(self._buffer) + + if len(parts) <= 1: + # No boundary found — keep everything in buffer + return [] + + sentences: list[str] = [] + carry = "" + # All parts except the last are complete sentences + for i in range(len(parts) - 1): + text = carry + parts[i] + carry = "" + stripped = text.strip() + if len(stripped) >= self._min_sentence_length: + sentences.append(stripped) + elif stripped: + # Too short (e.g. "Dr.") — carry forward to next part + carry = text + # else: empty, skip + + # Last part stays in buffer (may be incomplete) + self._buffer = carry + parts[-1] + + return sentences