diff --git a/python/sglang/srt/entrypoints/openai/realtime/session.py b/python/sglang/srt/entrypoints/openai/realtime/session.py index c5951993e25d..ffcf01293119 100644 --- a/python/sglang/srt/entrypoints/openai/realtime/session.py +++ b/python/sglang/srt/entrypoints/openai/realtime/session.py @@ -9,7 +9,6 @@ from __future__ import annotations import asyncio -import io import json import logging import math @@ -18,7 +17,6 @@ import numpy as np import pybase64 -import soundfile as sf from fastapi import WebSocket, WebSocketDisconnect from openai.types.realtime import ( ConversationItemCreatedEvent, @@ -83,6 +81,15 @@ _SAMPLE_WIDTH = 2 +def _slice_pcm_from(buffer, start: int) -> bytes: + """Immutable snapshot of ``buffer[start:]`` via memoryview — one slice-sized + copy instead of the two ``bytes(bytearray)[start:]`` would do. ``buffer`` is + bytes or bytearray. Raises instead of silently returning empty.""" + if not (0 <= start <= len(buffer)): + raise ValueError(f"_slice_pcm_from: start={start} not in [0, {len(buffer)}]") + return bytes(memoryview(buffer)[start:]) + + def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> bytes: if src_rate == target_rate or not pcm: return pcm @@ -99,11 +106,10 @@ def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> byt return (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16).tobytes() -def _pcm_to_wav(pcm: bytes, sample_rate: int) -> bytes: - samples = np.frombuffer(pcm, dtype=np.int16) - buf = io.BytesIO() - sf.write(buf, samples, sample_rate, format="WAV") - return buf.getvalue() +def _pcm_to_float_samples(pcm: bytes) -> np.ndarray: + # /32768.0 matches soundfile.read's default int16 normalization so the + # samples are bit-equal to the prior PCM→WAV→sf.read path. + return np.frombuffer(pcm, dtype=np.int16).astype(np.float64) / 32768.0 _CLIENT_EVENT_TYPES: Dict[str, type] = { @@ -139,17 +145,29 @@ class _SessionConfig: @dataclass class _AudioState: - """Per-item audio state: PCM buffer accumulated from - input_audio_buffer.append, the chunked ASR rollback state, and the - static buffer-size limits set at __init__. pcm_buffer / state / - last_inference_offset reset on commit-roll and clear; the size limits - stay constant for the session's lifetime.""" + """Per-item audio state. Once the slicing gate is reached (``state.emitted_text`` + non-empty AND ``state.chunk_index >= slicing_min_chunk_index``), inference + switches from the cumulative buffer to a tail slice at + ``pcm_buffer[committed_audio_until_bytes - left_overlap_bytes:]``. The FIRST + gated call still starts at offset 0 because ``committed_audio_until_bytes`` is + initialized to 0; only subsequent calls are bounded to the left overlap plus + newly appended audio. ``emitted_text`` is not injected into the prompt — the + retained acoustic overlap plus output-side dedupe takes the place of a + continuation prefix.""" max_buffer_bytes: int chunk_size_bytes: int + left_overlap_bytes: int + slicing_min_chunk_index: int state: StreamingASRState + # False when left_overlap covers the whole unfixed-chunk window, which + # leaves the K-unfixed dedupe target unreachable; flipped at session + # construction. When False, _run_inference always takes the cumulative + # path even after emitted_text becomes non-empty. + slicing_enabled: bool = True pcm_buffer: bytearray = field(default_factory=bytearray) last_inference_offset: int = 0 + committed_audio_until_bytes: int = 0 @dataclass @@ -190,6 +208,11 @@ def __init__( self.config = _SessionConfig() + slicing_cfg = adapter.realtime_slicing_config + left_overlap_ms = int(slicing_cfg["left_overlap_ms"]) + min_audio_sec = float(slicing_cfg["min_audio_sec"]) + left_overlap_bytes = int(left_overlap_ms / 1000 * self.bytes_per_second) + state = StreamingASRState(**adapter.chunked_streaming_config) chunk_size_bytes = int(state.chunk_size_sec * self.bytes_per_second) if chunk_size_bytes <= 0: @@ -197,10 +220,24 @@ def __init__( f"adapter.chunked_streaming_config produced non-positive " f"chunk_size_sec; got {state.chunk_size_sec!r}" ) + slicing_min_chunk_index = math.ceil(min_audio_sec / state.chunk_size_sec) + slicing_enabled = ( + left_overlap_bytes < state.unfixed_chunk_num * chunk_size_bytes + ) + if not slicing_enabled: + logger.warning( + "[realtime] left_overlap=%dms >= unfixed_chunks_duration=%dms; " + "audio slicing disabled, falling back to cumulative inference", + left_overlap_ms, + state.unfixed_chunk_num * int(state.chunk_size_sec * 1000), + ) self.audio = _AudioState( max_buffer_bytes=self.max_buffer_seconds * self.bytes_per_second, chunk_size_bytes=chunk_size_bytes, state=state, + left_overlap_bytes=left_overlap_bytes, + slicing_min_chunk_index=slicing_min_chunk_index, + slicing_enabled=slicing_enabled, ) self.item = _ItemState(current_item_id=f"item_{random_uuid()}") @@ -543,8 +580,7 @@ async def _on_input_audio_buffer_commit( tail = self.audio.state.finalize() await self._emit_transcription_delta(tail) - # Build from emitted_deltas, not state.full_transcript: prefix injection - # means the last chunk's full_transcript is only the continuation tail. + # Use emitted_deltas: under slicing, state.full_transcript is the deduped tail. transcript = normalize_whitespace("".join(self.item.emitted_deltas)) await self._send( @@ -582,17 +618,39 @@ async def _run_inference(self, is_last: bool) -> bool: """Run ASR on the current cumulative buffer. Returns False on failure: commit-time emits transcription.failed and rolls the item; append-time emits a generic error envelope and closes the WebSocket.""" - wav_data = await asyncio.to_thread( - _pcm_to_wav, bytes(self.audio.pcm_buffer), self.model_sample_rate + # Bare prompt under slicing: emitted_text is not injected as a + # continuation prefix; the retained overlap + output dedupe + # takes its place. + committed_text = self.audio.state.get_prefix_text() + slicing_engaged = ( + self.audio.slicing_enabled + and bool(committed_text) + and self.audio.state.chunk_index >= self.audio.slicing_min_chunk_index ) + if slicing_engaged: + prompt: Optional[str] = self.adapter.prompt_template + dedupe_against: Optional[str] = committed_text + slice_start = max( + 0, + self.audio.committed_audio_until_bytes - self.audio.left_overlap_bytes, + ) + else: + prompt = None + dedupe_against = None + slice_start = 0 + + pcm_slice = _slice_pcm_from(self.audio.pcm_buffer, slice_start) + audio_samples = await asyncio.to_thread(_pcm_to_float_samples, pcm_slice) try: delta = await process_asr_chunk( tokenizer_manager=self.tokenizer_manager, adapter=self.adapter, state=self.audio.state, - audio_data=wav_data, + audio_data=audio_samples, sampling_params=self.config.sampling_params, is_last=is_last, + prompt=prompt, + dedupe_against=dedupe_against, ) except Exception: logger.exception( @@ -632,6 +690,9 @@ async def _run_inference(self, is_last: bool) -> bool: ) return False + if slicing_engaged: + self.audio.committed_audio_until_bytes = len(self.audio.pcm_buffer) + self.audio.last_inference_offset = len(self.audio.pcm_buffer) await self._emit_transcription_delta(delta) return True @@ -669,6 +730,7 @@ def _reset_inference_state(self) -> None: self.audio.pcm_buffer.clear() # in-place; reuses the buffer's allocation self.item.emitted_deltas.clear() self.audio.last_inference_offset = 0 + self.audio.committed_audio_until_bytes = 0 def _build_session_info(self) -> TranscriptionSessionConfig: # id / object aren't SDK fields; round-trip via extra='allow' so diff --git a/python/sglang/srt/entrypoints/openai/streaming_asr.py b/python/sglang/srt/entrypoints/openai/streaming_asr.py index a347cc8f3e33..a37f17eeb4b5 100644 --- a/python/sglang/srt/entrypoints/openai/streaming_asr.py +++ b/python/sglang/srt/entrypoints/openai/streaming_asr.py @@ -3,8 +3,9 @@ import logging import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union +import numpy as np import soundfile as sf from fastapi import Request @@ -130,18 +131,15 @@ def normalize_whitespace(text: str) -> str: def _is_cjk(c: str) -> bool: - """Whether char is a CJK-context glyph that doesn't take inter-word - spaces — ideographs, Japanese kana, CJK punctuation, fullwidth forms. - Excludes Hangul / Devanagari / Arabic etc., which are non-ASCII but - space-separated and need the normal boundary space.""" + """CJK-context glyph that doesn't take inter-word spaces.""" cp = ord(c) return ( - 0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation (,。、《》「」…) + 0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation or 0x3040 <= cp <= 0x309F # Hiragana or 0x30A0 <= cp <= 0x30FF # Katakana or 0x3400 <= cp <= 0x4DBF # CJK Unified Ideographs Ext A or 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs - or 0xFF00 <= cp <= 0xFFEF # Halfwidth & Fullwidth Forms (fullwidth ASCII) + or 0xFF00 <= cp <= 0xFFEF # Halfwidth & Fullwidth Forms ) @@ -162,18 +160,120 @@ def needs_space(prev: str, cur: str) -> bool: return True +# Trailing punctuation stripped during dedupe match. Includes em dash +# (U+2014), hyphen-minus, and CJK fullwidth equivalents. +_DEDUPE_NORM_STRIP = ",.!?;:—-,。!?;:、" + + +def _dedupe_norm(word: str) -> str: + """Lowercase + strip trailing punctuation for dedupe matching.""" + return word.strip(_DEDUPE_NORM_STRIP).lower() + + +def _dedupe_word_level(committed_text: str, candidate_out: str) -> str: + """Drop the longest prefix of ``candidate_out`` matching the suffix of + ``committed_text`` word-for-word (case- and punctuation-insensitive).""" + cand_words = candidate_out.split() + if not cand_words: + return candidate_out + c_words = committed_text.split() + if not c_words: + return candidate_out + # Longest possible overlap is bounded by candidate length; normalize + # only that tail of committed text instead of scanning the whole history. + # Pre-normalize once instead of O(k²) calls inside the inner loop, then + # compare list slices in C rather than glyph-by-glyph in Python. + max_k = min(len(c_words), len(cand_words)) + c_norm = [_dedupe_norm(w) for w in c_words[-max_k:]] + cand_norm = [_dedupe_norm(w) for w in cand_words] + for k in range(max_k, 0, -1): + if c_norm[-k:] == cand_norm[:k]: + return " ".join(cand_words[k:]) + return candidate_out + + +def _find_kth_cjk_pos(text: str, k: int) -> Optional[int]: + """Return index after the k-th CJK glyph in text, or None if text + contains fewer than k CJK glyphs.""" + seen = 0 + for i, c in enumerate(text): + if c.isspace() or not _is_cjk(c): + continue + seen += 1 + if seen == k: + return i + 1 + return None + + +def _dedupe_cjk_char_level(committed_text: str, candidate_out: str) -> str: + """Drop leading CJK glyphs of ``candidate_out`` matching the CJK-tail of + ``committed_text``. Non-CJK glyphs are skipped during match, preserved + in trimmed output.""" + cand_chars = [c for c in candidate_out if not c.isspace() and _is_cjk(c)] + if not cand_chars: + return candidate_out + # Longest possible overlap is bounded by candidate CJK length; collect + # only that tail of committed CJK glyphs instead of scanning the whole + # history. We iterate committed_text in reverse and stop once we have + # len(cand_chars) CJK glyphs. + max_cand = len(cand_chars) + c_tail_rev = [] + for c in reversed(committed_text): + if c.isspace() or not _is_cjk(c): + continue + c_tail_rev.append(c) + if len(c_tail_rev) >= max_cand: + break + if not c_tail_rev: + return candidate_out + c_chars = list(reversed(c_tail_rev)) + max_k = min(len(c_chars), len(cand_chars)) + for k in range(max_k, 0, -1): + if c_chars[-k:] != cand_chars[:k]: + continue + cut_pos = _find_kth_cjk_pos(candidate_out, k) + if cut_pos is None: + return "" + return candidate_out[cut_pos:].lstrip() + return candidate_out + + +def dedupe_overlap(committed_text: str, candidate_out: str) -> str: + """Trim words/CJK glyphs at the start of ``candidate_out`` that + re-transcribe ``committed_text``'s tail. Word-level first, CJK + char-level fallback.""" + if not committed_text or not candidate_out: + return candidate_out + deduped = _dedupe_word_level(committed_text, candidate_out) + if deduped != candidate_out: + return deduped + if any(_is_cjk(c) for c in committed_text) or any( + _is_cjk(c) for c in candidate_out + ): + return _dedupe_cjk_char_level(committed_text, candidate_out) + return candidate_out + + async def process_asr_chunk( tokenizer_manager: TokenizerManager, adapter: TranscriptionAdapter, state: StreamingASRState, - audio_data: bytes, + audio_data: Union[bytes, np.ndarray], sampling_params: Dict[str, Any], is_last: bool, raw_request: Optional[Request] = None, routing_key: Optional[str] = None, + prompt: Optional[str] = None, + dedupe_against: Optional[str] = None, ) -> str: - """Run inference on one audio chunk. Shared by the HTTP and WebSocket paths.""" - prompt = adapter.prompt_template + state.get_prefix_text() + """Run inference on one audio chunk. Shared by the HTTP and WS paths. + + ``audio_data`` accepts WAV bytes or pre-decoded float samples. + ``prompt`` overrides the default ``adapter.prompt_template + state.get_prefix_text()``. + ``dedupe_against`` triggers ``dedupe_overlap`` on raw model output before ``state`` ingests it. + """ + if prompt is None: + prompt = adapter.prompt_template + state.get_prefix_text() chunk_request = GenerateReqInput( text=prompt, @@ -202,6 +302,8 @@ async def process_asr_chunk( return "" text = normalize_whitespace(adapter.postprocess_text(ret.get("text", ""))) + if dedupe_against is not None: + text = dedupe_overlap(dedupe_against, text) if is_last: state.full_transcript = text diff --git a/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py b/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py index cd97b42997f9..1120289ec3ff 100644 --- a/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py +++ b/python/sglang/srt/entrypoints/openai/transcription_adapters/base.py @@ -107,6 +107,20 @@ def chunked_streaming_config(self) -> dict: """ return {} + @property + def realtime_slicing_config(self) -> dict: + """Tuning knobs for the WS realtime slicing path. Only consulted + when ``supports_chunked_streaming`` is True. Override per adapter + when the model's token rate or per-chunk stability differs. + + ``left_overlap_ms``: audio kept across the sliced boundary so + dedupe has context; cover the K-token rollback window. + ``min_audio_sec``: don't slice below this many seconds of + cumulative audio (sliced output diverges from cumulative + on short inputs and dedupe over-matches). + """ + return {"left_overlap_ms": 2000, "min_audio_sec": 16.0} + def postprocess_text(self, text: str) -> str: """Strip model-specific markers from raw decoded text. diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index feb505d5dd6b..7bbaaa40204c 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -772,11 +772,20 @@ def set_random_seed(seed: int) -> None: def load_audio( - audio_file: str, sr: Optional[int] = None, mono: bool = True + audio_file: Union[str, bytes, np.ndarray], + sr: Optional[int] = None, + mono: bool = True, ) -> np.ndarray: if sr is None: sr = 16000 + # Caller must pre-resample to `sr`. Multi-channel layout assumed + # (n_samples, n_channels) per soundfile.read. + if isinstance(audio_file, np.ndarray): + if mono and audio_file.ndim > 1: + return np.mean(audio_file, axis=1) + return audio_file + # Normalize input: resolve URL / base64 / file:// to bytes or path if isinstance(audio_file, bytes): source = audio_file diff --git a/test/registered/unit/entrypoints/openai/test_streaming_asr.py b/test/registered/unit/entrypoints/openai/test_streaming_asr.py new file mode 100644 index 000000000000..a73a29e3aebb --- /dev/null +++ b/test/registered/unit/entrypoints/openai/test_streaming_asr.py @@ -0,0 +1,113 @@ +"""Unit tests for realtime ASR slicing-path helpers. + +Edge cases for ``dedupe_overlap`` (normalization rules, CJK fallback, the +suffix-only-history invariant the perf optimization depends on), the +bit-equality invariant for ``_pcm_to_float_samples``, and ``_slice_pcm_from`` +validation. Trivial happy-path assertions that restated Python primitives were +dropped. The slicing trigger logic and its interaction with +``StreamingASRState`` are exercised by the manual GPU suite, not by CI. +""" + +import io +import unittest + +import numpy as np +import soundfile as sf + +from sglang.srt.entrypoints.openai.realtime.session import ( + _pcm_to_float_samples, + _slice_pcm_from, +) +from sglang.srt.entrypoints.openai.streaming_asr import dedupe_overlap +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=2, suite="base-a-test-cpu") + + +class TestDedupeOverlap(CustomTestCase): + """Edge cases for the dedupe heuristic. + + Drops trivial happy-path assertions; keeps cases that lock + normalization rules, CJK fallback paths, and the suffix-only-history + invariant that the perf optimization relies on. + """ + + def test_full_candidate_overlaps_returns_empty(self): + # Whole-candidate match must emit empty so StreamingASRState doesn't + # double-record the previous chunk's content. + self.assertEqual(dedupe_overlap("hello world", "hello world"), "") + + def test_empty_committed_returns_candidate_unchanged(self): + self.assertEqual(dedupe_overlap("", "anything goes"), "anything goes") + + def test_empty_candidate_returns_empty(self): + self.assertEqual(dedupe_overlap("anything", ""), "") + + def test_em_dash_normalized_during_match(self): + # Trailing em dash and case differences are stripped during match. + # Regression test for the dedupe rule documented in _DEDUPE_NORM_STRIP. + self.assertEqual( + dedupe_overlap("stew for dinner—", "Dinner: turnips"), "turnips" + ) + + def test_cjk_char_level_fallback(self): + # No whitespace → word-level returns unchanged → CJK fallback engages. + self.assertEqual(dedupe_overlap("你好世界", "世界今天很好"), "今天很好") + + def test_cjk_overlap_with_punctuation(self): + # CJK punctuation in committed_text must not block the char-level + # match on the ideographs that follow. + self.assertEqual(dedupe_overlap("你好,世界", "世界今天很好"), "今天很好") + + def test_long_committed_history_uses_suffix_overlap(self): + # Locks the suffix-only invariant the tail-only optimization + # depends on: a massive committed prefix unrelated to the candidate + # must not change the match outcome. + committed = " ".join(["old"] * 1000 + ["a", "b", "c"]) + self.assertEqual(dedupe_overlap(committed, "b c d"), "d") + + +class TestPcmToFloatSamples(CustomTestCase): + """The bit-equality invariant the PR's perf claim depends on, plus the + one corruption boundary worth catching loudly.""" + + def test_matches_soundfile_round_trip(self): + # The PCM→WAV→sf.read path was the legacy converter; this PR's + # direct conversion must remain bit-equal to it. + rng = np.random.default_rng(42) + ints = rng.integers(-32768, 32768, size=4096, dtype=np.int16) + pcm = ints.tobytes() + + direct = _pcm_to_float_samples(pcm) + + buf = io.BytesIO() + sf.write(buf, ints, 16000, format="WAV") + buf.seek(0) + round_trip, _ = sf.read(buf) + + np.testing.assert_array_equal(direct, round_trip) + + def test_odd_length_pcm_raises(self): + # int16 frames are 2 bytes; an odd-length buffer means upstream + # corruption. Keep the np.frombuffer ValueError loud — silent + # rounding would mask the bug. + with self.assertRaises(ValueError): + _pcm_to_float_samples(b"\x00") + + +class TestSlicePcmFrom(CustomTestCase): + """Only the validation behavior — the trivial slice cases were + Python-built-in tests, not ours.""" + + def test_negative_start_raises(self): + with self.assertRaises(ValueError): + _slice_pcm_from(b"abcdef", -1) + + def test_past_end_raises(self): + with self.assertRaises(ValueError): + _slice_pcm_from(b"abcdef", 7) + + +if __name__ == "__main__": + unittest.main()