diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index 5540b518f7dd..5f98a6299931 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -21,6 +21,7 @@ from __future__ import annotations +import asyncio import io import logging import math @@ -42,6 +43,10 @@ TranscriptionVerboseResponse, ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase +from sglang.srt.entrypoints.openai.streaming_asr import ( + StreamingASRState, + split_audio_chunks, +) from sglang.srt.entrypoints.openai.transcription_adapters import resolve_adapter from sglang.srt.managers.io_struct import GenerateReqInput @@ -178,6 +183,15 @@ async def _handle_streaming_request( raw_request: Request, ) -> StreamingResponse: """Handle streaming transcription request.""" + if self._adapter.supports_chunked_streaming: + # No background abort_task: each chunk is a separate request; + # client disconnection is detected via is_disconnected() in the loop. + return StreamingResponse( + self._generate_chunked_asr_stream( + adapted_request, request, raw_request + ), + media_type="text/event-stream", + ) return StreamingResponse( self._generate_transcription_stream(adapted_request, request, raw_request), media_type="text/event-stream", @@ -241,3 +255,114 @@ async def _generate_transcription_stream( yield f"data: {error}\n\n" yield "data: [DONE]\n\n" + + async def _generate_chunked_asr_stream( + self, + adapted_request: GenerateReqInput, + request: TranscriptionRequest, + raw_request: Request, + ) -> AsyncGenerator[str, None]: + """Chunk-based streaming for ASR with prefix rollback. + + Audio is split into chunks and each chunk is processed as an + independent request. Partial transcripts are emitted via SSE + with prefix rollback to reduce boundary jitter. 
+ + TODO: + - Token-level streaming within chunks (stream=True) + - Encoder window caching across chunks + - Cross-chunk KV cache reuse + - WebSocket endpoint for real-time audio input + """ + created_time = int(time.time()) + request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}" + model = request.model + state = StreamingASRState(**self._adapter.chunked_streaming_config) + first_word = True + + try: + chunks = split_audio_chunks(request.audio_data, state.chunk_size_sec) + + for i, chunk_audio in enumerate(chunks): + if await raw_request.is_disconnected(): + logger.info("[streaming_asr] client disconnected, stopping") + break + is_last = i == len(chunks) - 1 + prompt = self._adapter.prompt_template + state.get_prefix_text() + + chunk_request = GenerateReqInput( + text=prompt, + audio_data=chunk_audio, + sampling_params=adapted_request.sampling_params, + stream=False, + modalities=["audio"], + routing_key=self.extract_routing_key(raw_request), + ) + + try: + ret = None + async for ret in self.tokenizer_manager.generate_request( + chunk_request, raw_request + ): + break + except asyncio.CancelledError: + raise + except ValueError as e: + logger.warning( + "[streaming_asr] chunk %d failed with ValueError: %s", i, e + ) + continue + + if ret is None: + logger.warning("[streaming_asr] empty response for chunk %d", i) + continue + + text = self._adapter.postprocess_text(ret.get("text", "")) + + if is_last: + state.full_transcript = text + delta = state.finalize() + else: + delta = state.update(text) + + if delta: + for word in delta.split(" "): + if not word: + continue + content = word if first_word else " " + word + first_word = False + chunk_resp = TranscriptionStreamResponse( + id=request_id, + created=created_time, + model=model, + choices=[ + TranscriptionStreamChoice( + delta=DeltaMessage(content=content), + finish_reason=None, + ) + ], + ) + yield f"data: {chunk_resp.model_dump_json()}\n\n" + + # Send final stop + chunk_resp = TranscriptionStreamResponse( 
+ id=request_id, + created=created_time, + model=model, + choices=[ + TranscriptionStreamChoice( + delta=DeltaMessage(), + finish_reason="stop", + ) + ], + ) + yield f"data: {chunk_resp.model_dump_json()}\n\n" + + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception("[streaming_asr] unrecoverable error") + error = self.create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" + + yield "data: [DONE]\n\n" diff --git a/python/sglang/srt/entrypoints/openai/streaming_asr.py b/python/sglang/srt/entrypoints/openai/streaming_asr.py new file mode 100644 index 000000000000..77a808b23bc1 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/streaming_asr.py @@ -0,0 +1,93 @@ +import io +from dataclasses import dataclass +from typing import List + +import soundfile as sf + + +@dataclass +class StreamingASRState: + """State for chunk-based streaming ASR with prefix rollback. + + Parameters are model-specific and should be provided via the + adapter's ``chunked_streaming_config``. + + Known limitation: rollback uses str.split() which is ineffective + for CJK languages (no whitespace between words). + TODO: implement token-level rollback to handle all languages + correctly. 
@dataclass
class StreamingASRState:
    """State for chunk-based streaming ASR with prefix rollback.

    After every chunk the last ``unfixed_token_num`` whitespace-separated
    words of the transcript are treated as unstable and held back; the words
    before them are "confirmed". ``update`` and ``finalize`` return only the
    text that has not been reported yet.

    Parameters are model-specific and should be provided via the
    adapter's ``chunked_streaming_config``.

    Known limitation: rollback uses str.split() which is ineffective
    for CJK languages (no whitespace between words).
    TODO: implement token-level rollback to handle all languages
    correctly.
    """

    # Chunk length in seconds.
    chunk_size_sec: float
    # Number of initial chunks for which get_prefix_text() returns "".
    unfixed_chunk_num: int
    # Number of trailing words held back as unstable after each update.
    unfixed_token_num: int
    # Words considered stable so far, joined by single spaces.
    confirmed_text: str = ""
    # Most recent transcript passed to update() (or assigned by the caller).
    full_transcript: str = ""
    # Number of update() calls made so far.
    chunk_index: int = 0

    @staticmethod
    def _common_word_count(left: List[str], right: List[str]) -> int:
        """Return how many leading words ``left`` and ``right`` share."""
        count = 0
        for lw, rw in zip(left, right):
            if lw != rw:
                break
            count += 1
        return count

    def get_prefix_text(self) -> str:
        """Return the confirmed text to use as a prefix for the next chunk.

        Returns an empty string while fewer than ``unfixed_chunk_num`` chunks
        have been processed or when nothing has been confirmed yet.
        """
        if self.chunk_index < self.unfixed_chunk_num or not self.confirmed_text:
            return ""
        return self.confirmed_text

    def update(self, new_transcript: str) -> str:
        """Record the transcript of a non-final chunk and return the new stable text.

        All but the last ``unfixed_token_num`` words of ``new_transcript``
        become the confirmed text. The return value is the part of the new
        confirmed text that follows the words it shares with the previous
        confirmed text, so revised words are never emitted as fragments.

        Args:
            new_transcript: Transcript produced for the current chunk.

        Returns:
            Newly confirmed words joined by single spaces, or ``""`` when
            nothing new was confirmed.
        """
        old_words = self.confirmed_text.split()
        words = new_transcript.split()
        # Compute the cut point as a non-negative count instead of slicing
        # with ``-unfixed_token_num``: ``words[:-0]`` is ``words[:0]`` and
        # would confirm nothing when no words are meant to be held back.
        keep = max(len(words) - self.unfixed_token_num, 0)
        new_words = words[:keep]
        self.confirmed_text = " ".join(new_words)
        self.full_transcript = new_transcript
        self.chunk_index += 1
        # Always compare word by word. A character-level prefix test would
        # treat "the cat" as a prefix of "the cats sat" and emit "s sat".
        common_count = self._common_word_count(old_words, new_words)
        return " ".join(new_words[common_count:])

    def finalize(self) -> str:
        """Confirm ``full_transcript`` in full and return the text not yet confirmed.

        Returns:
            The words of ``full_transcript`` after the leading words it shares
            with the confirmed text. If both are non-empty but share no
            leading word, the whole ``full_transcript`` is returned unchanged.
        """
        confirmed_words = self.confirmed_text.split()
        all_words = self.full_transcript.split()
        # Use word level common prefix to handle punctuation differences
        # between intermediate chunks and the final full transcription.
        common_count = self._common_word_count(confirmed_words, all_words)
        self.confirmed_text = self.full_transcript
        if common_count == 0 and confirmed_words and all_words:
            return self.full_transcript
        return " ".join(all_words[common_count:])
def split_audio_chunks(audio_data: bytes, chunk_size_sec: float) -> List[bytes]:
    """Split encoded audio into cumulative WAV chunks.

    The audio is decoded with ``soundfile``, multi-channel input is averaged
    down to one channel, and a WAV file is produced for every multiple of
    ``chunk_size_sec``. Each chunk holds the audio from the very start up to
    its end point (``data[:end]``), so chunk ``k`` contains everything in
    chunk ``k - 1`` plus the next slice; the last chunk is the whole signal.

    Args:
        audio_data: Encoded audio file contents.
        chunk_size_sec: Amount of audio, in seconds, added per chunk.

    Returns:
        WAV-encoded byte strings, one per chunk. Empty when the decoded audio
        has no samples.

    Raises:
        ValueError: If ``audio_data`` is empty, ``chunk_size_sec`` is not a
            positive finite number, ``chunk_size_sec`` is shorter than one
            sample at the file's sample rate, or the audio cannot be decoded.
    """
    import math  # local import: only needed for the finiteness check

    if not audio_data:
        raise ValueError("audio_data is empty")
    # NaN slips through ``<= 0`` and inf would overflow in int() below, so
    # reject both explicitly and keep ValueError as the only failure type.
    if not math.isfinite(chunk_size_sec):
        raise ValueError(f"chunk_size_sec must be finite, got {chunk_size_sec}")
    if chunk_size_sec <= 0:
        raise ValueError(f"chunk_size_sec must be positive, got {chunk_size_sec}")
    audio_file = io.BytesIO(audio_data)
    try:
        data, sample_rate = sf.read(audio_file, dtype="float32")
    except sf.LibsndfileError as e:
        raise ValueError(f"failed to decode audio: {e}") from e
    if len(data.shape) > 1:
        # Down-mix multi-channel audio by averaging across channels.
        data = data.mean(axis=1)
    chunk_size_samples = int(chunk_size_sec * sample_rate)
    if chunk_size_samples < 1:
        # A zero step would make range() fail with an unrelated message.
        raise ValueError(
            f"chunk_size_sec={chunk_size_sec} is shorter than one sample "
            f"at {sample_rate} Hz"
        )
    total_samples = len(data)
    chunks = []
    # End points run chunk_size, 2*chunk_size, ... and the last one is
    # clamped to the signal length.
    for end in range(
        chunk_size_samples, total_samples + chunk_size_samples, chunk_size_samples
    ):
        end = min(end, total_samples)
        buf = io.BytesIO()
        sf.write(buf, data[:end], sample_rate, format="WAV")
        chunks.append(buf.getvalue())
    return chunks
+ + Only used when ``supports_chunked_streaming`` is True. + The default returns an empty string. + """ + return "" + + @property + def chunked_streaming_config(self) -> dict: + """Parameters for ``StreamingASRState`` when using chunked streaming. + + Only used when ``supports_chunked_streaming`` is True. + Keys: ``chunk_size_sec``, ``unfixed_chunk_num``, ``unfixed_token_num``. + """ + return {} + def postprocess_text(self, text: str) -> str: """Strip model-specific markers from raw decoded text. diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py index dca58ec84fb0..df686b15aecb 100644 --- a/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/qwen3_asr.py @@ -9,12 +9,32 @@ TranscriptionAdapter, register_transcription_adapter, ) +from sglang.srt.multimodal.processors.qwen3_asr import DEFAULT_ASR_PROMPT @register_transcription_adapter("Qwen3ASR") class Qwen3ASRAdapter(TranscriptionAdapter): ASR_TEXT_TAG = "" + @property + def supports_chunked_streaming(self) -> bool: + return True + + @property + def chunked_streaming_config(self) -> dict: + # Qwen3-ASR paper (arXiv:2601.21337), Table 8 uses 4 unfixed chunks. + # We use 2 here for lower latency; tune based on quality needs. + # TODO: allow users to override these via API request parameters. 
+ return { + "chunk_size_sec": 2.0, + "unfixed_chunk_num": 2, + "unfixed_token_num": 5, + } + + @property + def prompt_template(self) -> str: + return DEFAULT_ASR_PROMPT + def build_sampling_params(self, request: TranscriptionRequest) -> dict: temperature = request.temperature if temperature == 0.0: diff --git a/python/sglang/srt/multimodal/processors/qwen3_asr.py b/python/sglang/srt/multimodal/processors/qwen3_asr.py index 546dbc13708f..31368077f256 100644 --- a/python/sglang/srt/multimodal/processors/qwen3_asr.py +++ b/python/sglang/srt/multimodal/processors/qwen3_asr.py @@ -12,7 +12,7 @@ AUDIO_PLACEHOLDER = "<|audio_start|><|audio_pad|><|audio_end|>" -_DEFAULT_ASR_PROMPT = ( +DEFAULT_ASR_PROMPT = ( f"<|im_start|>user\n" f"{AUDIO_PLACEHOLDER}" f"<|im_end|>\n" @@ -47,7 +47,7 @@ def _build_transcription_prompt(self, input_text: Union[str, list]) -> str: if isinstance(input_text, list): input_text = self._tokenizer.decode(input_text) if not input_text or not input_text.strip(): - return _DEFAULT_ASR_PROMPT + return DEFAULT_ASR_PROMPT return input_text def compute_mrope_positions(self, input_ids, mm_items):