diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index aba143c83c16..f15e638dfd21 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -54,6 +54,7 @@ Query, Request, UploadFile, + WebSocket, ) from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -1580,6 +1581,12 @@ async def openai_v1_audio_transcriptions( ) +@app.websocket("/v1/audio/transcriptions/stream") +async def openai_v1_audio_transcriptions_ws(ws: WebSocket): + """WebSocket endpoint for real-time streaming audio transcription.""" + await ws.app.state.openai_serving_transcription.handle_websocket(ws) + + @app.get("/v1/models", response_class=ORJSONResponse) async def available_models(): """Show available models. OpenAI-compatible endpoint.""" diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index 5f98a6299931..1fe377642c2c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -29,7 +29,7 @@ import uuid from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Union -from fastapi import Request +from fastapi import Request, WebSocket from fastapi.responses import ORJSONResponse, Response, StreamingResponse from sglang.srt.entrypoints.openai.protocol import ( @@ -45,6 +45,7 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.streaming_asr import ( StreamingASRState, + process_asr_chunk, split_audio_chunks, ) from sglang.srt.entrypoints.openai.transcription_adapters import resolve_adapter @@ -272,7 +273,6 @@ async def _generate_chunked_asr_stream( - Token-level streaming within chunks (stream=True) - Encoder window caching across chunks - Cross-chunk KV cache reuse - - WebSocket endpoint for real-time audio input """ created_time = int(time.time()) request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}" @@ -288,43 +288,18 @@ async def _generate_chunked_asr_stream( logger.info("[streaming_asr] client disconnected, stopping") break is_last = i == len(chunks) - 1 - prompt = self._adapter.prompt_template + state.get_prefix_text() - chunk_request = GenerateReqInput( - text=prompt, + delta = await process_asr_chunk( + tokenizer_manager=self.tokenizer_manager, + adapter=self._adapter, + state=state, audio_data=chunk_audio, sampling_params=adapted_request.sampling_params, - stream=False, - modalities=["audio"], + is_last=is_last, + raw_request=raw_request, routing_key=self.extract_routing_key(raw_request), ) - try: - ret = None - async for ret in self.tokenizer_manager.generate_request( - chunk_request, raw_request - ): - break - except asyncio.CancelledError: - raise - except ValueError as e: - logger.warning( - "[streaming_asr] chunk %d failed with ValueError: %s", i, e - ) - continue - - if ret is None: - logger.warning("[streaming_asr] empty response for chunk %d", i) - continue - - text = self._adapter.postprocess_text(ret.get("text", "")) - - if is_last: - state.full_transcript = text - delta = state.finalize() - else: - delta = state.update(text) - if delta: for word in delta.split(" "): if not word: @@ -366,3 +341,11 @@ async def _generate_chunked_asr_stream( yield f"data: {error}\n\n" yield "data: [DONE]\n\n" + + async def handle_websocket(self, websocket: WebSocket) -> None: + """Handle a Realtime transcription session over WebSocket.""" + from sglang.srt.entrypoints.openai.serving_transcription_websocket import ( + handle_realtime_transcription, + ) + + await handle_realtime_transcription(websocket, self) diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription_websocket.py b/python/sglang/srt/entrypoints/openai/serving_transcription_websocket.py new file mode 100644 index 000000000000..c3345ff317a1 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_transcription_websocket.py @@ -0,0 +1,350 @@ +"""WebSocket transport for OpenAI Realtime API-style transcription. + +The wire protocol mirrors OpenAI's Realtime API conventions +(``session.start`` / ``transcript.delta`` / ``transcript.final``) so the +``Realtime*`` symbol prefix refers to the protocol identity, not the +transport. A future gRPC streaming variant for the same OpenAI-style +protocol would live in ``serving_transcription_grpc.py`` and could reuse +the same enum/class names without collision. +""" + +import io +import json +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + +import numpy as np +import soundfile as sf +from fastapi import WebSocket, WebSocketDisconnect + +from sglang.srt.entrypoints.openai.protocol import TranscriptionRequest +from sglang.srt.entrypoints.openai.streaming_asr import ( + StreamingASRState, + _normalize_whitespace, + process_asr_chunk, +) +from sglang.srt.entrypoints.websocket_base import WebSocketSessionBase + +logger = logging.getLogger(__name__) + +# Realtime transcription protocol-fixed audio format: PCM16, 16 kHz, mono, LE. +_SAMPLE_RATE = 16000 +_SAMPLE_WIDTH = 2 # bytes per sample (int16) +_BYTES_PER_SECOND = _SAMPLE_RATE * _SAMPLE_WIDTH + + +def _pcm_to_wav(pcm_buffer: bytes) -> bytes: + """Wrap raw PCM16 mono 16 kHz bytes in a WAV container so + ``soundfile.read`` (called by the multimodal processor) can decode it. + """ + if not pcm_buffer: + raise ValueError("pcm_buffer is empty") + samples = np.frombuffer(pcm_buffer, dtype=np.int16) + buf = io.BytesIO() + sf.write(buf, samples, _SAMPLE_RATE, format="WAV") + return buf.getvalue() + + +class RealtimeMessageType(str, Enum): + SESSION_START = "session.start" + SESSION_END = "session.end" + SESSION_STARTED = "session.started" + TRANSCRIPT_DELTA = "transcript.delta" + TRANSCRIPT_FINAL = "transcript.final" + ERROR = "error" + + +class RealtimeErrorCode(str, Enum): + UNSUPPORTED_MODEL = "unsupported_model" + INVALID_STATE = "invalid_state" + INVALID_JSON = "invalid_json" + INVALID_PAYLOAD = "invalid_payload" + INVALID_AUDIO_FORMAT = "invalid_audio_format" + UNKNOWN_MESSAGE = "unknown_message" + BUFFER_OVERFLOW = "buffer_overflow" + INTERNAL_ERROR = "internal_error" + + +@dataclass(kw_only=True) +class RealtimeTranscriptionSession(WebSocketSessionBase): + """OpenAI Realtime API-style session for live transcription over WebSocket.""" + + state: StreamingASRState + chunk_size_bytes: int + max_buffer_bytes: int + max_buffer_seconds: int + pcm_buffer: bytearray = field(default_factory=bytearray) + last_inference_offset: int = 0 + total_audio_bytes: int = 0 + started: bool = False + emitted_words: List[str] = field(default_factory=list) + sampling_params: Optional[dict] = None + model: Optional[str] = None + language: Optional[str] = None + + @property + def has_new_audio(self) -> bool: + return len(self.pcm_buffer) > self.last_inference_offset + + @property + def should_trigger_inference(self) -> bool: + return ( + len(self.pcm_buffer) - self.last_inference_offset >= self.chunk_size_bytes + ) + + def mark_inferred(self) -> None: + self.last_inference_offset = len(self.pcm_buffer) + + def duration_sec(self) -> float: + return round(self.total_audio_bytes / _BYTES_PER_SECOND, 2) + + async def send_error(self, code: RealtimeErrorCode, message: str) -> None: + """Send an OpenAI Realtime-style flat error event.""" + await self.send_json( + {"type": RealtimeMessageType.ERROR, "code": code, "message": message} + ) + + +async def handle_realtime_transcription(websocket: WebSocket, serving) -> None: + """Handle a Realtime transcription session over WebSocket. + + Single-task: receive and inference share one coroutine; PCM queues in OS + buffers during inference (capped by ``asr_max_buffer_seconds``). + ``session.end`` is therefore serialized after any in-flight chunk. + """ + session = await _init_session(websocket, serving) + if session is None: + return + try: + await _run_session_loop(serving, session) + except WebSocketDisconnect: + logger.info( + "[realtime_transcription] client disconnected: %s", session.session_id + ) + except Exception: + logger.exception( + "[realtime_transcription] unrecoverable error: %s", session.session_id + ) + try: + await session.send_error( + RealtimeErrorCode.INTERNAL_ERROR, "Internal server error" + ) + except (WebSocketDisconnect, RuntimeError): + pass + finally: + await session.safe_close() + + +async def _init_session( + websocket: WebSocket, serving +) -> Optional[RealtimeTranscriptionSession]: + """Construct and accept the session. Returns None if the adapter rejects.""" + adapter = serving._adapter + if not adapter.supports_chunked_streaming: + # Pre-session error: accept just to deliver the error, then close. + await websocket.accept() + await websocket.send_text( + json.dumps( + { + "type": RealtimeMessageType.ERROR, + "code": RealtimeErrorCode.UNSUPPORTED_MODEL, + "message": "Model does not support streaming ASR", + } + ) + ) + await websocket.close() + return None + + state = StreamingASRState(**adapter.chunked_streaming_config) + max_buffer_seconds = serving.tokenizer_manager.server_args.asr_max_buffer_seconds + session = RealtimeTranscriptionSession( + websocket=websocket, + state=state, + chunk_size_bytes=int(state.chunk_size_sec * _BYTES_PER_SECOND), + max_buffer_bytes=max_buffer_seconds * _BYTES_PER_SECOND, + max_buffer_seconds=max_buffer_seconds, + ) + await session.accept() + return session + + +async def _run_session_loop(serving, session: RealtimeTranscriptionSession) -> None: + """Main receive/dispatch loop. Returns when session should terminate.""" + while True: + message = await session.websocket.receive() + if message["type"] == "websocket.disconnect": + return + + text = message.get("text") + data = message.get("bytes") + if text: + if await _handle_control_message(serving, session, text): + return + elif data: + if await _handle_audio_frame(serving, session, data): + return + + +async def _handle_control_message( + serving, session: RealtimeTranscriptionSession, text: str +) -> bool: + """Process a JSON control message. Returns True if the session should end.""" + try: + ctrl = json.loads(text) + except json.JSONDecodeError: + await session.send_error(RealtimeErrorCode.INVALID_JSON, "Invalid JSON") + return False + if not isinstance(ctrl, dict): + await session.send_error( + RealtimeErrorCode.INVALID_PAYLOAD, + "Control message must be a JSON object", + ) + return False + + msg_type = ctrl.get("type", "") + if msg_type == RealtimeMessageType.SESSION_START: + await _handle_session_start(serving, session, ctrl) + return False + if msg_type == RealtimeMessageType.SESSION_END: + await _handle_session_end(serving, session) + return True # session.end always terminates the loop + + await session.send_error( + RealtimeErrorCode.UNKNOWN_MESSAGE, f"Unknown message type: {msg_type}" + ) + return False + + +async def _handle_session_start( + serving, session: RealtimeTranscriptionSession, ctrl: dict +) -> None: + if session.started: + await session.send_error( + RealtimeErrorCode.INVALID_STATE, "Session already started" + ) + return + + raw_model = ctrl.get("model") + raw_language = ctrl.get("language") + if raw_model is not None and not isinstance(raw_model, str): + await session.send_error( + RealtimeErrorCode.INVALID_PAYLOAD, + "session.start.model must be a string", + ) + return + if raw_language is not None and not isinstance(raw_language, str): + await session.send_error( + RealtimeErrorCode.INVALID_PAYLOAD, + "session.start.language must be a string", + ) + return + + session.model = raw_model + session.language = raw_language + adapter = serving._adapter + session.sampling_params = adapter.build_sampling_params( + TranscriptionRequest(language=raw_language) + if raw_language + else TranscriptionRequest() + ) + session.started = True + await session.send_json( + { + "type": RealtimeMessageType.SESSION_STARTED, + "session_id": session.session_id, + "model": session.model, + } + ) + + +async def _handle_session_end(serving, session: RealtimeTranscriptionSession) -> None: + if not session.started: + await session.send_error(RealtimeErrorCode.INVALID_STATE, "No active session") + return + + if session.has_new_audio: + await _run_inference(serving, session, is_last=True) + elif session.state.full_transcript: + # No new audio but update() held back the unfixed tail; flush via + # finalize() without another inference. + tail = session.state.finalize() + await _emit_delta(session, tail) + + # Re-normalize: batched inference can split punctuation into its own + # word (["him", ","]), making " ".join(...) leak orphan spaces. + await session.send_json( + { + "type": RealtimeMessageType.TRANSCRIPT_FINAL, + "text": _normalize_whitespace(" ".join(session.emitted_words)), + "duration_sec": session.duration_sec(), + "model": session.model, + } + ) + + +async def _handle_audio_frame( + serving, session: RealtimeTranscriptionSession, data: bytes +) -> bool: + """Append an audio frame and maybe trigger inference. Returns True on overflow.""" + if not session.started: + await session.send_error( + RealtimeErrorCode.INVALID_STATE, + "Send session.start before streaming audio", + ) + return False + if len(data) % _SAMPLE_WIDTH != 0: + await session.send_error( + RealtimeErrorCode.INVALID_AUDIO_FORMAT, + f"PCM16 frame length must be a multiple of {_SAMPLE_WIDTH} bytes", + ) + return False + + session.pcm_buffer.extend(data) + session.total_audio_bytes += len(data) + + if len(session.pcm_buffer) > session.max_buffer_bytes: + await session.send_error( + RealtimeErrorCode.BUFFER_OVERFLOW, + f"Accumulated audio exceeded {session.max_buffer_seconds}s; " + "client is sending faster than inference can keep up", + ) + return True + + # Cumulative buffer: each inference sees all audio so far, + # but trigger only once per chunk_size of new audio. + if session.should_trigger_inference: + await _run_inference(serving, session, is_last=False) + return False + + +async def _run_inference( + serving, session: RealtimeTranscriptionSession, *, is_last: bool +) -> None: + wav_data = _pcm_to_wav(bytes(session.pcm_buffer)) + delta = await process_asr_chunk( + tokenizer_manager=serving.tokenizer_manager, + adapter=serving._adapter, + state=session.state, + audio_data=wav_data, + sampling_params=session.sampling_params, + is_last=is_last, + ) + session.mark_inferred() + await _emit_delta(session, delta) + + +async def _emit_delta(session: RealtimeTranscriptionSession, delta: str) -> None: + if not delta: + return + for word in delta.split(" "): + if not word: + continue + session.emitted_words.append(word) + await session.send_json( + { + "type": RealtimeMessageType.TRANSCRIPT_DELTA, + "delta": word, + } + ) diff --git a/python/sglang/srt/entrypoints/openai/streaming_asr.py b/python/sglang/srt/entrypoints/openai/streaming_asr.py index 77a808b23bc1..d24c6567a89e 100644 --- a/python/sglang/srt/entrypoints/openai/streaming_asr.py +++ b/python/sglang/srt/entrypoints/openai/streaming_asr.py @@ -1,9 +1,21 @@ +import asyncio import io +import logging +import re from dataclasses import dataclass -from typing import List +from typing import List, Optional import soundfile as sf +from sglang.srt.managers.io_struct import GenerateReqInput + +logger = logging.getLogger(__name__) + + +# Collapse whitespace before punctuation so batched-inference token +# boundary jitter (" ," vs ",") doesn't leak into deltas. +_PUNCT_WS_RE = re.compile(r"\s+([,.;:!?])") + @dataclass class StreamingASRState: @@ -22,13 +34,23 @@ class StreamingASRState: unfixed_chunk_num: int unfixed_token_num: int confirmed_text: str = "" + # Monotonic accumulator; used as prompt prefix so the model sees a + # natural continuation point, not the rolled-back ``confirmed_text``. + emitted_text: str = "" full_transcript: str = "" chunk_index: int = 0 def get_prefix_text(self) -> str: - if self.chunk_index < self.unfixed_chunk_num or not self.confirmed_text: + if self.chunk_index < self.unfixed_chunk_num or not self.emitted_text: return "" - return self.confirmed_text + return self.emitted_text + + def _record_emit(self, delta: str) -> str: + if delta: + self.emitted_text = ( + f"{self.emitted_text} {delta}".strip() if self.emitted_text else delta + ) + return delta def update(self, new_transcript: str) -> str: old_confirmed = self.confirmed_text @@ -40,7 +62,7 @@ def update(self, new_transcript: str) -> str: self.full_transcript = new_transcript self.chunk_index += 1 if self.confirmed_text.startswith(old_confirmed): - return self.confirmed_text[len(old_confirmed) :].strip() + return self._record_emit(self.confirmed_text[len(old_confirmed) :].strip()) # Model revised earlier text, use word level common prefix to avoid # re-emitting already-sent content and cutting mid-word. old_words = old_confirmed.split() @@ -50,7 +72,7 @@ def update(self, new_transcript: str) -> str: if ow != nw: break common_count += 1 - return " ".join(new_words[common_count:]) + return self._record_emit(" ".join(new_words[common_count:])) def finalize(self) -> str: confirmed_words = self.confirmed_text.split() @@ -64,8 +86,8 @@ def finalize(self) -> str: common_count += 1 self.confirmed_text = self.full_transcript if common_count == 0 and confirmed_words and all_words: - return self.full_transcript - return " ".join(all_words[common_count:]) + return self._record_emit(self.full_transcript) + return self._record_emit(" ".join(all_words[common_count:])) def split_audio_chunks(audio_data: bytes, chunk_size_sec: float) -> List[bytes]: @@ -91,3 +113,52 @@ def split_audio_chunks(audio_data: bytes, chunk_size_sec: float) -> List[bytes]: sf.write(buf, data[:end], sample_rate, format="WAV") chunks.append(buf.getvalue()) return chunks + + +def _normalize_whitespace(text: str) -> str: + return _PUNCT_WS_RE.sub(r"\1", text) + + +async def process_asr_chunk( + tokenizer_manager, + adapter, + state: StreamingASRState, + audio_data: bytes, + sampling_params: dict, + is_last: bool, + raw_request=None, + routing_key: Optional[str] = None, +) -> str: + """Run inference on one audio chunk. Shared by the HTTP and WebSocket paths.""" + prompt = adapter.prompt_template + state.get_prefix_text() + + chunk_request = GenerateReqInput( + text=prompt, + audio_data=audio_data, + sampling_params=sampling_params, + stream=False, + modalities=["audio"], + ) + if routing_key is not None: + chunk_request.routing_key = routing_key + + try: + ret = None + async for ret in tokenizer_manager.generate_request(chunk_request, raw_request): + break + except asyncio.CancelledError: + raise + except ValueError as e: + logger.warning("[streaming_asr] chunk %d failed: %s", state.chunk_index, e) + return "" + + if ret is None: + logger.warning("[streaming_asr] empty response for chunk %d", state.chunk_index) + return "" + + text = _normalize_whitespace(adapter.postprocess_text(ret.get("text", ""))) + + if is_last: + state.full_transcript = text + return state.finalize() + return state.update(text) diff --git a/python/sglang/srt/entrypoints/websocket_base.py b/python/sglang/srt/entrypoints/websocket_base.py new file mode 100644 index 000000000000..27243d54ddb3 --- /dev/null +++ b/python/sglang/srt/entrypoints/websocket_base.py @@ -0,0 +1,30 @@ +import json +import uuid +from dataclasses import dataclass, field + +from fastapi import WebSocket, WebSocketDisconnect + + +@dataclass(kw_only=True) +class WebSocketSessionBase: + """Minimal base for persistent WebSocket sessions. + + Provides JSON send / accept / safe close. Subclasses are responsible + for the receive loop, message dispatch, error event format, and any + protocol-specific state. + """ + + websocket: WebSocket + session_id: str = field(default_factory=lambda: f"sess_{uuid.uuid4().hex[:12]}") + + async def accept(self) -> None: + await self.websocket.accept() + + async def send_json(self, data: dict) -> None: + await self.websocket.send_text(json.dumps(data)) + + async def safe_close(self) -> None: + try: + await self.websocket.close() + except (WebSocketDisconnect, RuntimeError): + pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8cea237cc2d5..ed85c667220d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -444,6 +444,7 @@ class ServerArgs: tool_call_parser: Optional[str] = None tool_server: Optional[str] = None sampling_defaults: str = "model" + asr_max_buffer_seconds: int = 60 # Data parallelism dp_size: int = 1 @@ -4830,6 +4831,15 @@ def add_cli_args(parser: argparse.ArgumentParser): "'model' uses the model's generation_config.json to get the recommended " "sampling parameters if available. Default is 'model'.", ) + parser.add_argument( + "--asr-max-buffer-seconds", + type=int, + default=ServerArgs.asr_max_buffer_seconds, + help="Maximum seconds of PCM audio the streaming ASR WebSocket handler " + "will accumulate before closing the session with a buffer_overflow " + "error. Guards against OOM when a client streams audio faster than " + "inference can consume it. Default 60s.", + ) # Data parallelism parser.add_argument( diff --git a/test/manual/models/test_qwen3_asr.py b/test/manual/models/test_qwen3_asr.py index c0b772bf5a6e..875a8ef59ed0 100644 --- a/test/manual/models/test_qwen3_asr.py +++ b/test/manual/models/test_qwen3_asr.py @@ -1,17 +1,30 @@ """ Test Qwen3-ASR model support in SGLang. -Tests /v1/audio/transcriptions endpoint (OpenAI-compatible). +Tests /v1/audio/transcriptions (HTTP) and +/v1/audio/transcriptions/stream (WebSocket live audio input). Usage: python test/manual/models/test_qwen3_asr.py """ +import asyncio import io +import json import os +import re import unittest +import numpy as np import requests +import soundfile as sf + +try: + import websockets + + HAS_WEBSOCKETS = True +except ImportError: + HAS_WEBSOCKETS = False from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -29,8 +42,96 @@ TEST_AUDIO_ZH_URL = ( "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav" ) +TEST_AUDIO_MLK_URL = ( + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac" +) +TEST_AUDIO_LIBRI_URL = ( + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac" +) +TEST_AUDIO_SPANISH_URL = ( + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/4.flac" +) +TEST_AUDIO_HINDI_URL = ( + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/hindi.ogg" +) +TEST_AUDIO_MP3_URL = ( + "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/i-know-kung-fu.mp3" +) TEST_AUDIO_EN_LOCAL = "/tmp/test_qwen3_asr_en.wav" TEST_AUDIO_ZH_LOCAL = "/tmp/test_qwen3_asr_zh.wav" +TEST_AUDIO_MLK_LOCAL = "/tmp/test_qwen3_asr_mlk.flac" +TEST_AUDIO_LIBRI_LOCAL = "/tmp/test_qwen3_asr_libri.flac" +TEST_AUDIO_SPANISH_LOCAL = "/tmp/test_qwen3_asr_spanish.flac" +TEST_AUDIO_HINDI_LOCAL = "/tmp/test_qwen3_asr_hindi.ogg" +TEST_AUDIO_MP3_LOCAL = "/tmp/test_qwen3_asr_kungfu.mp3" + +# Captured from Qwen3-ASR-0.6B non-streaming inference (2026-04-14). +# Refresh if model weights or sampling params change. +EXPECTED_TRANSCRIPTS = { + "en": ( + "Oh yeah, yeah. He wasn't even that big when I started listening to him." + " But and his solo music didn't do overly well, but he did very well" + " when he started writing for other people." + ), + "zh": "甚至出现交易几乎停滞的情况。", + "mlk": ( + "I have a dream that one day this nation will rise up and live out" + " the true meaning of its creed." + ), + "libri": ( + "He hoped there would be stew for dinner—turnips and carrots and" + " bruised potatoes and fat mutton pieces—to be ladled out in thick" + " peppered flour-fatted sauce." + ), + "spanish": ( + "y en las ramas medio sumergidas revoloteaban algunos pájaros" + " de químico y legendario plumaje" + ), + "hindi": "मिर्ची में कितने विभिन्न प्रजातियाँ हैं", + "mp3": "I know kung fu.", +} + + +def _normalize_for_wer(text: str) -> list: + """Lowercase, strip punctuation, split on whitespace. + + Used by ``_wer`` so that chunked-streaming artifacts that differ from + one-shot only in punctuation / casing (``—`` vs ``:``, trailing period, + leading capitalization) don't count as errors. + """ + text = text.lower() + text = re.sub(r"[^\w\s\u0900-\u097f\u4e00-\u9fff]+", " ", text) + return text.split() + + +def _wer(hypothesis: str, reference: str) -> float: + """Word error rate via Levenshtein distance on normalized tokens. + + Returns ``edit_distance(hyp_words, ref_words) / len(ref_words)``. For + CJK text where ``str.split()`` degenerates, we fall back to character + distance so the metric still means something. + """ + hyp = _normalize_for_wer(hypothesis) + ref = _normalize_for_wer(reference) + if len(ref) <= 1 and not any(" " in w for w in ref): + # CJK fallback: compare at char level + hyp = list(hypothesis.replace(" ", "")) + ref = list(reference.replace(" ", "")) + if not ref: + return 0.0 if not hyp else float("inf") + # Standard Levenshtein DP + n, m = len(hyp), len(ref) + dp = list(range(m + 1)) + for i in range(1, n + 1): + prev, dp[0] = dp[0], i + for j in range(1, m + 1): + cur = dp[j] + if hyp[i - 1] == ref[j - 1]: + dp[j] = prev + else: + dp[j] = 1 + min(prev, dp[j - 1], dp[j]) + prev = cur + return dp[m] / len(ref) def download_audio(url, local_path): @@ -45,8 +146,78 @@ def download_audio(url, local_path): return resp.content +def _pcm16_from_audio_bytes(audio_bytes): + """Decode audio bytes, resample to 16kHz mono, return (pcm_bytes, sample_rate).""" + data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32") + if len(data.shape) > 1: + data = data.mean(axis=1) + target_sr = 16000 + if sr != target_sr: + num_samples = int(len(data) / sr * target_sr) + indices = np.linspace(0, len(data) - 1, num_samples) + data = np.interp(indices, np.arange(len(data)), data) + sr = target_sr + pcm = (data * 32767).astype(np.int16).tobytes() + return pcm, sr + + +async def _stream_websocket_async( + websocket_url, pcm_bytes, sample_rate, language=None, realtime=False +): + """Stream PCM over WebSocket; return {text, deltas, session_id, duration_sec}. + + If realtime=True, sleeps between chunks to simulate live audio pacing. + """ + chunk_duration = 0.5 + chunk_bytes = int(chunk_duration * sample_rate * 2) # int16 = 2 bytes + + async with websockets.connect(websocket_url) as websocket: + start_msg = {"type": "session.start"} + if language: + start_msg["language"] = language + await websocket.send(json.dumps(start_msg)) + ack = json.loads(await websocket.recv()) + assert ack.get("type") == "session.started", f"unexpected ack: {ack}" + session_id = ack["session_id"] + + deltas = [] + final_msg = {} + + async def receive_loop(): + async for raw in websocket: + resp = json.loads(raw) + if resp["type"] == "transcript.delta": + deltas.append(resp["delta"]) + elif resp["type"] == "transcript.final": + final_msg.update(resp) + return + elif resp["type"] == "error": + raise RuntimeError( + f"websocket error [{resp.get('code', '?')}]: {resp.get('message', '')}" + ) + + receiver = asyncio.create_task(receive_loop()) + + for offset in range(0, len(pcm_bytes), chunk_bytes): + chunk = pcm_bytes[offset : offset + chunk_bytes] + await websocket.send(chunk) + if realtime: + await asyncio.sleep(chunk_duration) + + await websocket.send(json.dumps({"type": "session.end"})) + await receiver + + assert final_msg, "no transcript.final received" + return { + "text": final_msg.get("text", ""), + "deltas": deltas, + "session_id": session_id, + "duration_sec": final_msg.get("duration_sec", 0.0), + } + + class TestQwen3ASRTranscription(CustomTestCase): - """Test Qwen3-ASR via /v1/audio/transcriptions endpoint.""" + """Test Qwen3-ASR via HTTP /v1/audio/transcriptions and WebSocket /v1/audio/transcriptions/stream.""" @classmethod def setUpClass(cls): @@ -67,8 +238,12 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) + # ------------------------------------------------------------------ + # HTTP path + # ------------------------------------------------------------------ + def _transcribe(self, audio_url, local_path, language=None): - """Send a transcription request.""" + """Send an HTTP transcription request.""" audio_bytes = download_audio(audio_url, local_path) data = {"model": "qwen3-asr"} if language: @@ -98,6 +273,50 @@ def test_chinese_transcription(self): self.assertTrue(len(text) > 0, "Transcription should not be empty") print(f"[ZH Transcription] {text}") + def test_mlk_transcription(self): + """13s MLK speech (FLAC 22050 Hz) — HTTP non-stream ground truth.""" + result = self._transcribe(TEST_AUDIO_MLK_URL, TEST_AUDIO_MLK_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[MLK Transcription] {text}") + + def test_librispeech_transcription(self): + """10s LibriSpeech-style FLAC 16 kHz — HTTP non-stream ground truth.""" + result = self._transcribe(TEST_AUDIO_LIBRI_URL, TEST_AUDIO_LIBRI_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[LibriSpeech Transcription] {text}") + + def test_spanish_transcription(self): + """Spanish FLAC 48 kHz PCM_24 — HTTP non-stream ground truth.""" + result = self._transcribe( + TEST_AUDIO_SPANISH_URL, TEST_AUDIO_SPANISH_LOCAL, language="es" + ) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[Spanish Transcription] {text}") + + def test_hindi_transcription(self): + """Hindi OGG/Opus 16 kHz — HTTP non-stream ground truth.""" + result = self._transcribe( + TEST_AUDIO_HINDI_URL, TEST_AUDIO_HINDI_LOCAL, language="hi" + ) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[Hindi Transcription] {text}") + + def test_mp3_stereo_transcription(self): + """MP3 stereo 44.1 kHz — HTTP non-stream ground truth.""" + result = self._transcribe(TEST_AUDIO_MP3_URL, TEST_AUDIO_MP3_LOCAL) + self.assertIn("text", result) + text = result["text"] + self.assertTrue(len(text) > 0, "Transcription should not be empty") + print(f"[MP3 Transcription] {text}") + def test_multiple_requests_consistency(self): """Test that repeated requests produce consistent output.""" results = [] @@ -113,6 +332,194 @@ def test_multiple_requests_consistency(self): ) print(f"[Consistency] All 3 requests match: {results[0][:80]}...") + # ------------------------------------------------------------------ + # WebSocket path + # ------------------------------------------------------------------ + + def _websocket_url(self): + return ( + self.base_url.replace("http://", "ws://").replace("https://", "wss://") + + "/v1/audio/transcriptions/stream" + ) + + def _stream_websocket(self, audio_url, local_path, language=None, realtime=False): + audio_bytes = download_audio(audio_url, local_path) + pcm, sr = _pcm16_from_audio_bytes(audio_bytes) + return asyncio.run( + _stream_websocket_async( + self._websocket_url(), pcm, sr, language=language, realtime=realtime + ) + ) + + def _assert_close_to_ref( + self, hypothesis: str, ref_key: str, max_wer: float = 0.15 + ): + """Assert a streamed transcript stays within ``max_wer`` of the reference. + + Chunked streaming inherits a few artifacts from #22089 that one-shot + does not — "Uh huh." short-context hallucination on long English + audio, mid-sentence punctuation drift (``—`` → ``:``), and trailing + periods. We accept up to 15% WER (normalized, case/punct stripped) + so these don't break CI, while still catching real regressions + like dropped words or double-emitted phrases. + """ + reference = EXPECTED_TRANSCRIPTS[ref_key] + wer = _wer(hypothesis, reference) + self.assertLessEqual( + wer, + max_wer, + f"WER {wer:.3f} > {max_wer} for {ref_key!r}\n" + f" hyp: {hypothesis!r}\n ref: {reference!r}", + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_english_websocket_streaming(self): + """Test English audio transcription over WebSocket (fast mode).""" + result = self._stream_websocket(TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL) + self._assert_close_to_ref(result["text"], "en") + self.assertGreater(len(result["deltas"]), 0) + print( + f"[EN WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_chinese_websocket_streaming(self): + """Test Chinese audio transcription over WebSocket with session.start.language.""" + result = self._stream_websocket( + TEST_AUDIO_ZH_URL, TEST_AUDIO_ZH_LOCAL, language="zh" + ) + self._assert_close_to_ref(result["text"], "zh") + print( + f"[ZH WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_websocket_streaming_realtime(self): + """Exercise real-time pacing: sleep between chunks to simulate live audio.""" + result = self._stream_websocket( + TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL, realtime=True + ) + self._assert_close_to_ref(result["text"], "en") + self.assertGreaterEqual( + len(result["deltas"]), 2, "realtime mode should yield multiple deltas" + ) + print( + f"[Realtime WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_mlk_speech_websocket_streaming(self): + """13s English MLK speech, FLAC @ 22050 Hz (exercises client-side resampling).""" + result = self._stream_websocket(TEST_AUDIO_MLK_URL, TEST_AUDIO_MLK_LOCAL) + self._assert_close_to_ref(result["text"], "mlk") + print( + f"[MLK WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_librispeech_dummy_websocket_streaming(self): + """10s LibriSpeech-style FLAC @ 16 kHz (no resampling needed).""" + result = self._stream_websocket(TEST_AUDIO_LIBRI_URL, TEST_AUDIO_LIBRI_LOCAL) + self._assert_close_to_ref(result["text"], "libri") + print( + f"[LibriSpeech WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_websocket_concurrent_sessions(self): + """3 parallel WebSocket sessions; verify state isolation and independent finals.""" + audio_bytes = download_audio(TEST_AUDIO_EN_URL, TEST_AUDIO_EN_LOCAL) + pcm, sr = _pcm16_from_audio_bytes(audio_bytes) + + async def run_n_concurrent(n): + return await asyncio.gather( + *[ + _stream_websocket_async(self._websocket_url(), pcm, sr) + for _ in range(n) + ] + ) + + results = asyncio.run(run_n_concurrent(3)) + + session_ids = {r["session_id"] for r in results} + self.assertEqual(len(session_ids), 3, "each session must have a unique id") + for i, r in enumerate(results): + self.assertTrue( + len(r["text"]) > 0, f"session {i} should produce non-empty transcript" + ) + # All sessions ran the same audio, so finals should match. + finals = [r["text"] for r in results] + self.assertEqual( + len(set(finals)), + 1, + f"3 concurrent sessions on identical audio should yield identical finals, got {finals}", + ) + print( + f"[Concurrent x3 WS] all finals match: {finals[0]} " + f"(session_ids={sorted(session_ids)})" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_spanish_websocket_streaming(self): + """6.6s Spanish audio, FLAC @ 48 kHz PCM_24 (resampling + high-bit-depth).""" + result = self._stream_websocket( + TEST_AUDIO_SPANISH_URL, TEST_AUDIO_SPANISH_LOCAL, language="es" + ) + self._assert_close_to_ref(result["text"], "spanish") + print( + f"[Spanish WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_hindi_websocket_streaming(self): + """4s Hindi audio, OGG/Opus @ 16 kHz (multilingual + non-WAV/FLAC container).""" + result = self._stream_websocket( + TEST_AUDIO_HINDI_URL, TEST_AUDIO_HINDI_LOCAL, language="hi" + ) + self._assert_close_to_ref(result["text"], "hindi") + print( + f"[Hindi WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_mp3_stereo_websocket_streaming(self): + """4s English MP3 stereo @ 44.1 kHz (mp3 decode + stereo->mono + resample).""" + result = self._stream_websocket(TEST_AUDIO_MP3_URL, TEST_AUDIO_MP3_LOCAL) + self._assert_close_to_ref(result["text"], "mp3") + print( + f"[MP3 stereo WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + + @unittest.skipUnless(HAS_WEBSOCKETS, "websockets package not installed") + def test_websocket_short_clip(self): + """Sub-chunk clip: less than chunk_size_sec of real speech. + + Qwen3-ASR hallucinates badly on <2s inputs (short-context artifact), + so we use ~3s of the MP3 kungfu clip where speech actually starts. + This path still exercises the "session.end before any inference + trigger" branch because 3s < the session's accumulated-audio + threshold path when PCM is streamed in 0.5s frames. + """ + audio_bytes = download_audio(TEST_AUDIO_MP3_URL, TEST_AUDIO_MP3_LOCAL) + full_pcm, sr = _pcm16_from_audio_bytes(audio_bytes) + short_pcm = full_pcm[: sr * 2 * 3] # first 3 seconds + result = asyncio.run( + _stream_websocket_async(self._websocket_url(), short_pcm, sr) + ) + self._assert_close_to_ref(result["text"], "mp3") + print( + f"[Short clip WS] final={result['text']} " + f"({len(result['deltas'])} deltas, {result['duration_sec']}s)" + ) + if __name__ == "__main__": unittest.main(verbosity=3)