Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions python/sglang/srt/entrypoints/openai/serving_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from __future__ import annotations

import asyncio
import io
import logging
import math
Expand All @@ -42,6 +43,10 @@
TranscriptionVerboseResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.streaming_asr import (
StreamingASRState,
split_audio_chunks,
)
from sglang.srt.entrypoints.openai.transcription_adapters import resolve_adapter
from sglang.srt.managers.io_struct import GenerateReqInput

Expand Down Expand Up @@ -178,6 +183,15 @@ async def _handle_streaming_request(
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming transcription request."""
if self._adapter.supports_chunked_streaming:
# No background abort_task: each chunk is a separate request;
# client disconnection is detected via is_disconnected() in the loop.
return StreamingResponse(
self._generate_chunked_asr_stream(
adapted_request, request, raw_request
),
media_type="text/event-stream",
)
return StreamingResponse(
self._generate_transcription_stream(adapted_request, request, raw_request),
media_type="text/event-stream",
Expand Down Expand Up @@ -241,3 +255,114 @@ async def _generate_transcription_stream(
yield f"data: {error}\n\n"

yield "data: [DONE]\n\n"

async def _generate_chunked_asr_stream(
    self,
    adapted_request: GenerateReqInput,
    request: TranscriptionRequest,
    raw_request: Request,
) -> AsyncGenerator[str, None]:
    """Chunk-based streaming for ASR with prefix rollback.

    Audio is split into chunks and each chunk is processed as an
    independent, non-streaming generate request. Partial transcripts are
    emitted via SSE, one word per event, with prefix rollback to reduce
    boundary jitter.

    Args:
        adapted_request: Provides the ``sampling_params`` reused for every
            per-chunk request.
        request: Provides the audio bytes and the model name echoed in
            each event.
        raw_request: Polled for client disconnection before each chunk and
            forwarded to the tokenizer manager / routing-key extraction.

    Yields:
        SSE-framed strings: one event per transcribed word, then a
        ``finish_reason="stop"`` event (or an error event if something
        unrecoverable happened), and always a trailing ``[DONE]`` marker.

    TODO:
    - Token-level streaming within chunks (stream=True)
    - Encoder window caching across chunks
    - Cross-chunk KV cache reuse
    - WebSocket endpoint for real-time audio input
    """
    created_time = int(time.time())
    request_id = f"{self._request_id_prefix()}{uuid.uuid4().hex}"
    model = request.model
    state = StreamingASRState(**self._adapter.chunked_streaming_config)
    first_word = True

    def sse_event(delta_message, finish_reason=None) -> str:
        # Every event of one stream shares the same id / created / model.
        payload = TranscriptionStreamResponse(
            id=request_id,
            created=created_time,
            model=model,
            choices=[
                TranscriptionStreamChoice(
                    delta=delta_message,
                    finish_reason=finish_reason,
                )
            ],
        )
        return f"data: {payload.model_dump_json()}\n\n"

    def word_events(delta: str) -> list:
        # One event per word; every word except the very first of the
        # stream carries a leading space so the client can concatenate.
        nonlocal first_word
        events = []
        for word in delta.split(" "):
            if not word:
                continue
            content = word if first_word else " " + word
            first_word = False
            events.append(sse_event(DeltaMessage(content=content)))
        return events

    try:
        # NOTE(review): split_audio_chunks decodes/re-encodes audio
        # synchronously inside this coroutine - confirm that is acceptable
        # for long inputs.
        chunks = split_audio_chunks(request.audio_data, state.chunk_size_sec)
        last_index = len(chunks) - 1
        finalized = False
        disconnected = False

        for i, chunk_audio in enumerate(chunks):
            if await raw_request.is_disconnected():
                logger.info("[streaming_asr] client disconnected, stopping")
                disconnected = True
                break
            prompt = self._adapter.prompt_template + state.get_prefix_text()

            chunk_request = GenerateReqInput(
                text=prompt,
                audio_data=chunk_audio,
                sampling_params=adapted_request.sampling_params,
                stream=False,
                modalities=["audio"],
                routing_key=self.extract_routing_key(raw_request),
            )

            try:
                ret = None
                # stream=False: only the first yielded result is consumed.
                async for ret in self.tokenizer_manager.generate_request(
                    chunk_request, raw_request
                ):
                    break
            except asyncio.CancelledError:
                raise
            except ValueError as e:
                # A failed chunk is skipped; later chunks may still succeed.
                logger.warning(
                    "[streaming_asr] chunk %d failed with ValueError: %s", i, e
                )
                continue

            if ret is None:
                logger.warning("[streaming_asr] empty response for chunk %d", i)
                continue

            text = self._adapter.postprocess_text(ret.get("text", ""))

            if i == last_index:
                state.full_transcript = text
                delta = state.finalize()
                finalized = True
            else:
                delta = state.update(text)

            for event in word_events(delta):
                yield event

        if chunks and not finalized and not disconnected:
            # The last chunk failed or came back empty, so finalize() never
            # ran. Flush the words still held back from the most recent
            # successful chunk instead of silently dropping them.
            for event in word_events(state.finalize()):
                yield event

        # Send final stop
        yield sse_event(DeltaMessage(), finish_reason="stop")

    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.exception("[streaming_asr] unrecoverable error")
        error = self.create_streaming_error_response(str(e))
        yield f"data: {error}\n\n"

    yield "data: [DONE]\n\n"
93 changes: 93 additions & 0 deletions python/sglang/srt/entrypoints/openai/streaming_asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import io
from dataclasses import dataclass
from typing import List

import soundfile as sf


@dataclass
class StreamingASRState:
    """State for chunk-based streaming ASR with prefix rollback.

    Parameters are model-specific and should be provided via the
    adapter's ``chunked_streaming_config``.

    After each intermediate chunk, all but the trailing
    ``unfixed_token_num`` whitespace-separated words of the transcript are
    treated as confirmed; ``update`` returns the newly confirmed words and
    ``finalize`` returns whatever remains once the full transcript is known.

    Known limitation: rollback uses str.split() which is ineffective
    for CJK languages (no whitespace between words).
    TODO: implement token-level rollback to handle all languages
    correctly.
    """

    # Chunk length in seconds.
    chunk_size_sec: float
    # Number of initial chunks for which ``get_prefix_text`` returns "".
    unfixed_chunk_num: int
    # Number of trailing words of each transcript that stay unconfirmed.
    unfixed_token_num: int
    # Words confirmed so far, joined by single spaces.
    confirmed_text: str = ""
    # Most recent complete transcript handed to ``update`` (or set by the caller).
    full_transcript: str = ""
    # Count of ``update`` calls so far.
    chunk_index: int = 0

    @staticmethod
    def _common_prefix_len(left: List[str], right: List[str]) -> int:
        """Return how many leading items ``left`` and ``right`` share."""
        count = 0
        for a, b in zip(left, right):
            if a != b:
                break
            count += 1
        return count

    def get_prefix_text(self) -> str:
        """Return the confirmed text to use as a prefix for the next chunk.

        Returns an empty string while fewer than ``unfixed_chunk_num``
        chunks have been processed or nothing is confirmed yet.
        """
        if self.chunk_index < self.unfixed_chunk_num or not self.confirmed_text:
            return ""
        return self.confirmed_text

    def update(self, new_transcript: str) -> str:
        """Record an intermediate transcript and return the newly confirmed words.

        Args:
            new_transcript: Transcript produced for the current chunk.

        Returns:
            The confirmed words that follow the word-level common prefix
            shared with the previously confirmed text, joined by spaces
            (empty if nothing new is confirmed).
        """
        old_words = self.confirmed_text.split()
        words = new_transcript.split()
        # Use an explicit end index: ``words[:-0]`` is ``words[:0]`` and
        # would confirm nothing when ``unfixed_token_num`` is 0.
        keep = len(words) - self.unfixed_token_num
        new_words = words[:keep] if keep > 0 else []
        self.confirmed_text = " ".join(new_words)
        self.full_transcript = new_transcript
        self.chunk_index += 1
        # Always diff at word level. A character-level ``startswith`` test
        # would treat "a b" as a prefix of "a bc d" and emit the fragment
        # "c d", cutting a word in half.
        common_count = self._common_prefix_len(old_words, new_words)
        return " ".join(new_words[common_count:])

    def finalize(self) -> str:
        """Confirm the whole of ``full_transcript`` and return the unsent remainder.

        Returns:
            The words of ``full_transcript`` after the word-level common
            prefix with the confirmed text. If both are non-empty but share
            no leading word, ``full_transcript`` is returned unchanged.
        """
        confirmed_words = self.confirmed_text.split()
        all_words = self.full_transcript.split()
        # Use word level common prefix to handle punctuation differences
        # between intermediate chunks and the final full transcription.
        common_count = self._common_prefix_len(confirmed_words, all_words)
        self.confirmed_text = self.full_transcript
        if common_count == 0 and confirmed_words and all_words:
            return self.full_transcript
        return " ".join(all_words[common_count:])


def split_audio_chunks(audio_data: bytes, chunk_size_sec: float) -> List[bytes]:
    """Decode audio and re-encode it as a list of growing WAV prefixes.

    Each returned item contains the audio from the very beginning up to the
    end of its chunk window (the slice is ``[:stop]``), so successive items
    are cumulative rather than disjoint. The final item covers all samples.

    Args:
        audio_data: Encoded audio bytes readable by ``soundfile``.
        chunk_size_sec: Chunk window length in seconds; must be positive.

    Returns:
        WAV-encoded byte strings, one per chunk window.

    Raises:
        ValueError: If ``audio_data`` is empty, ``chunk_size_sec`` is not
            positive, or the audio cannot be decoded.
    """
    if not audio_data:
        raise ValueError("audio_data is empty")
    if chunk_size_sec <= 0:
        raise ValueError(f"chunk_size_sec must be positive, got {chunk_size_sec}")

    try:
        samples, sample_rate = sf.read(io.BytesIO(audio_data), dtype="float32")
    except sf.LibsndfileError as e:
        raise ValueError(f"failed to decode audio: {e}") from e

    # Multi-dimensional data is averaged along axis 1 (presumably a
    # multi-channel to mono downmix).
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    step = int(chunk_size_sec * sample_rate)
    n_samples = len(samples)
    encoded: List[bytes] = []
    for start in range(0, n_samples, step):
        stop = min(start + step, n_samples)
        with io.BytesIO() as wav_buffer:
            sf.write(wav_buffer, samples[:stop], sample_rate, format="WAV")
            encoded.append(wav_buffer.getvalue())
    return encoded
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,29 @@ class TranscriptionAdapter(ABC):
def build_sampling_params(self, request: TranscriptionRequest) -> dict:
"""Return the ``sampling_params`` dict for ``GenerateReqInput``."""

@property
def supports_chunked_streaming(self) -> bool:
    """Tell whether the model streams chunk by chunk instead of token by token.

    The base implementation opts out by returning ``False``.
    """
    return False

@property
def prompt_template(self) -> str:
    """Prompt text used for chunked streaming requests.

    Consulted only when ``supports_chunked_streaming`` is True; the base
    implementation returns an empty string.
    """
    return ""

@property
def chunked_streaming_config(self) -> dict:
    """Keyword arguments for ``StreamingASRState`` in chunked streaming mode.

    Consulted only when ``supports_chunked_streaming`` is True.
    Expected keys: ``chunk_size_sec``, ``unfixed_chunk_num``,
    ``unfixed_token_num``. The base implementation returns an empty dict.
    """
    return {}

def postprocess_text(self, text: str) -> str:
"""Strip model-specific markers from raw decoded text.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,32 @@
TranscriptionAdapter,
register_transcription_adapter,
)
from sglang.srt.multimodal.processors.qwen3_asr import DEFAULT_ASR_PROMPT


@register_transcription_adapter("Qwen3ASR")
class Qwen3ASRAdapter(TranscriptionAdapter):
ASR_TEXT_TAG = "<asr_text>"

@property
def supports_chunked_streaming(self) -> bool:
    """Opt this adapter in to chunk-based streaming."""
    return True

@property
def chunked_streaming_config(self) -> dict:
    """Return the chunked-streaming parameters used for this adapter.

    Keys: ``chunk_size_sec``, ``unfixed_chunk_num``, ``unfixed_token_num``.
    """
    # Qwen3-ASR paper (arXiv:2601.21337), Table 8 uses 4 unfixed chunks.
    # We use 2 here for lower latency; tune based on quality needs.
    # TODO: allow users to override these via API request parameters.
    return dict(
        chunk_size_sec=2.0,
        unfixed_chunk_num=2,
        unfixed_token_num=5,
    )

@property
def prompt_template(self) -> str:
    """Return ``DEFAULT_ASR_PROMPT`` as the prompt for chunked streaming requests."""
    return DEFAULT_ASR_PROMPT

def build_sampling_params(self, request: TranscriptionRequest) -> dict:
temperature = request.temperature
if temperature == 0.0:
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/multimodal/processors/qwen3_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

AUDIO_PLACEHOLDER = "<|audio_start|><|audio_pad|><|audio_end|>"

_DEFAULT_ASR_PROMPT = (
DEFAULT_ASR_PROMPT = (
f"<|im_start|>user\n"
f"{AUDIO_PLACEHOLDER}"
f"<|im_end|>\n"
Expand Down Expand Up @@ -47,7 +47,7 @@ def _build_transcription_prompt(self, input_text: Union[str, list]) -> str:
if isinstance(input_text, list):
input_text = self._tokenizer.decode(input_text)
if not input_text or not input_text.strip():
return _DEFAULT_ASR_PROMPT
return DEFAULT_ASR_PROMPT
return input_text

def compute_mrope_positions(self, input_ids, mm_items):
Expand Down
Loading