Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 79 additions & 17 deletions python/sglang/srt/entrypoints/openai/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from __future__ import annotations

import asyncio
import io
import json
import logging
import math
Expand All @@ -18,7 +17,6 @@

import numpy as np
import pybase64
import soundfile as sf
from fastapi import WebSocket, WebSocketDisconnect
from openai.types.realtime import (
ConversationItemCreatedEvent,
Expand Down Expand Up @@ -83,6 +81,15 @@
_SAMPLE_WIDTH = 2


def _slice_pcm_from(buffer, start: int) -> bytes:
"""Immutable snapshot of ``buffer[start:]`` via memoryview — one slice-sized
copy instead of the two ``bytes(bytearray)[start:]`` would do. ``buffer`` is
bytes or bytearray. Raises instead of silently returning empty."""
if not (0 <= start <= len(buffer)):
raise ValueError(f"_slice_pcm_from: start={start} not in [0, {len(buffer)}]")
return bytes(memoryview(buffer)[start:])


def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> bytes:
if src_rate == target_rate or not pcm:
return pcm
Expand All @@ -99,11 +106,10 @@ def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> byt
return (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16).tobytes()


def _pcm_to_wav(pcm: bytes, sample_rate: int) -> bytes:
samples = np.frombuffer(pcm, dtype=np.int16)
buf = io.BytesIO()
sf.write(buf, samples, sample_rate, format="WAV")
return buf.getvalue()
def _pcm_to_float_samples(pcm: bytes) -> np.ndarray:
# /32768.0 matches soundfile.read's default int16 normalization so the
# samples are bit-equal to the prior PCM→WAV→sf.read path.
return np.frombuffer(pcm, dtype=np.int16).astype(np.float64) / 32768.0


_CLIENT_EVENT_TYPES: Dict[str, type] = {
Expand Down Expand Up @@ -139,17 +145,29 @@ class _SessionConfig:

@dataclass
class _AudioState:
"""Per-item audio state: PCM buffer accumulated from
input_audio_buffer.append, the chunked ASR rollback state, and the
static buffer-size limits set at __init__. pcm_buffer / state /
last_inference_offset reset on commit-roll and clear; the size limits
stay constant for the session's lifetime."""
"""Per-item audio state. Once the slicing gate is reached (``state.emitted_text``
non-empty AND ``state.chunk_index >= slicing_min_chunk_index``), inference
switches from the cumulative buffer to a tail slice at
``pcm_buffer[committed_audio_until_bytes - left_overlap_bytes:]``. The FIRST
gated call still starts at offset 0 because ``committed_audio_until_bytes`` is
initialized to 0; only subsequent calls are bounded to the left overlap plus
newly appended audio. ``emitted_text`` is not injected into the prompt — the
retained acoustic overlap plus output-side dedupe takes the place of a
continuation prefix."""

max_buffer_bytes: int
chunk_size_bytes: int
left_overlap_bytes: int
slicing_min_chunk_index: int
state: StreamingASRState
# False when left_overlap covers the whole unfixed-chunk window, which
# leaves the K-unfixed dedupe target unreachable; flipped at session
# construction. When False, _run_inference always takes the cumulative
# path even after emitted_text becomes non-empty.
slicing_enabled: bool = True
pcm_buffer: bytearray = field(default_factory=bytearray)
last_inference_offset: int = 0
committed_audio_until_bytes: int = 0


@dataclass
Expand Down Expand Up @@ -190,17 +208,36 @@ def __init__(

self.config = _SessionConfig()

slicing_cfg = adapter.realtime_slicing_config
left_overlap_ms = int(slicing_cfg["left_overlap_ms"])
min_audio_sec = float(slicing_cfg["min_audio_sec"])
left_overlap_bytes = int(left_overlap_ms / 1000 * self.bytes_per_second)

state = StreamingASRState(**adapter.chunked_streaming_config)
chunk_size_bytes = int(state.chunk_size_sec * self.bytes_per_second)
if chunk_size_bytes <= 0:
raise RuntimeError(
f"adapter.chunked_streaming_config produced non-positive "
f"chunk_size_sec; got {state.chunk_size_sec!r}"
)
slicing_min_chunk_index = math.ceil(min_audio_sec / state.chunk_size_sec)
slicing_enabled = (
left_overlap_bytes < state.unfixed_chunk_num * chunk_size_bytes
)
if not slicing_enabled:
logger.warning(
"[realtime] left_overlap=%dms >= unfixed_chunks_duration=%dms; "
"audio slicing disabled, falling back to cumulative inference",
left_overlap_ms,
state.unfixed_chunk_num * int(state.chunk_size_sec * 1000),
)
self.audio = _AudioState(
max_buffer_bytes=self.max_buffer_seconds * self.bytes_per_second,
chunk_size_bytes=chunk_size_bytes,
state=state,
left_overlap_bytes=left_overlap_bytes,
slicing_min_chunk_index=slicing_min_chunk_index,
slicing_enabled=slicing_enabled,
)

self.item = _ItemState(current_item_id=f"item_{random_uuid()}")
Expand Down Expand Up @@ -543,8 +580,7 @@ async def _on_input_audio_buffer_commit(
tail = self.audio.state.finalize()
await self._emit_transcription_delta(tail)

# Build from emitted_deltas, not state.full_transcript: prefix injection
# means the last chunk's full_transcript is only the continuation tail.
# Use emitted_deltas: under slicing, state.full_transcript is the deduped tail.
transcript = normalize_whitespace("".join(self.item.emitted_deltas))

await self._send(
Expand Down Expand Up @@ -582,17 +618,39 @@ async def _run_inference(self, is_last: bool) -> bool:
"""Run ASR on the current cumulative buffer. Returns False on failure:
commit-time emits transcription.failed and rolls the item; append-time
emits a generic error envelope and closes the WebSocket."""
wav_data = await asyncio.to_thread(
_pcm_to_wav, bytes(self.audio.pcm_buffer), self.model_sample_rate
# Bare prompt under slicing: emitted_text is not injected as a
# continuation prefix; the retained overlap + output dedupe
# takes its place.
committed_text = self.audio.state.get_prefix_text()
slicing_engaged = (
self.audio.slicing_enabled
and bool(committed_text)
and self.audio.state.chunk_index >= self.audio.slicing_min_chunk_index
)
if slicing_engaged:
prompt: Optional[str] = self.adapter.prompt_template
dedupe_against: Optional[str] = committed_text
slice_start = max(
0,
self.audio.committed_audio_until_bytes - self.audio.left_overlap_bytes,
)
else:
prompt = None
dedupe_against = None
slice_start = 0

pcm_slice = _slice_pcm_from(self.audio.pcm_buffer, slice_start)
audio_samples = await asyncio.to_thread(_pcm_to_float_samples, pcm_slice)
try:
delta = await process_asr_chunk(
tokenizer_manager=self.tokenizer_manager,
adapter=self.adapter,
state=self.audio.state,
audio_data=wav_data,
audio_data=audio_samples,
sampling_params=self.config.sampling_params,
is_last=is_last,
prompt=prompt,
dedupe_against=dedupe_against,
)
except Exception:
logger.exception(
Expand Down Expand Up @@ -632,6 +690,9 @@ async def _run_inference(self, is_last: bool) -> bool:
)
return False

if slicing_engaged:
self.audio.committed_audio_until_bytes = len(self.audio.pcm_buffer)

self.audio.last_inference_offset = len(self.audio.pcm_buffer)
await self._emit_transcription_delta(delta)
return True
Expand Down Expand Up @@ -669,6 +730,7 @@ def _reset_inference_state(self) -> None:
self.audio.pcm_buffer.clear() # in-place; reuses the buffer's allocation
self.item.emitted_deltas.clear()
self.audio.last_inference_offset = 0
self.audio.committed_audio_until_bytes = 0

def _build_session_info(self) -> TranscriptionSessionConfig:
# id / object aren't SDK fields; round-trip via extra='allow' so
Expand Down
122 changes: 112 additions & 10 deletions python/sglang/srt/entrypoints/openai/streaming_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
import soundfile as sf
from fastapi import Request

Expand Down Expand Up @@ -130,18 +131,15 @@ def normalize_whitespace(text: str) -> str:


def _is_cjk(c: str) -> bool:
"""Whether char is a CJK-context glyph that doesn't take inter-word
spaces — ideographs, Japanese kana, CJK punctuation, fullwidth forms.
Excludes Hangul / Devanagari / Arabic etc., which are non-ASCII but
space-separated and need the normal boundary space."""
"""CJK-context glyph that doesn't take inter-word spaces."""
cp = ord(c)
return (
0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation (,。、《》「」…)
0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation
or 0x3040 <= cp <= 0x309F # Hiragana
or 0x30A0 <= cp <= 0x30FF # Katakana
or 0x3400 <= cp <= 0x4DBF # CJK Unified Ideographs Ext A
or 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
or 0xFF00 <= cp <= 0xFFEF # Halfwidth & Fullwidth Forms (fullwidth ASCII)
or 0xFF00 <= cp <= 0xFFEF # Halfwidth & Fullwidth Forms
)


Expand All @@ -162,18 +160,120 @@ def needs_space(prev: str, cur: str) -> bool:
return True


# Trailing punctuation stripped during dedupe match. Includes em dash
# (U+2014), hyphen-minus, and CJK fullwidth equivalents.
_DEDUPE_NORM_STRIP = ",.!?;:—-,。!?;:、"


def _dedupe_norm(word: str) -> str:
"""Lowercase + strip trailing punctuation for dedupe matching."""
return word.strip(_DEDUPE_NORM_STRIP).lower()


def _dedupe_word_level(committed_text: str, candidate_out: str) -> str:
"""Drop the longest prefix of ``candidate_out`` matching the suffix of
``committed_text`` word-for-word (case- and punctuation-insensitive)."""
cand_words = candidate_out.split()
if not cand_words:
return candidate_out
c_words = committed_text.split()
if not c_words:
return candidate_out
# Longest possible overlap is bounded by candidate length; normalize
# only that tail of committed text instead of scanning the whole history.
# Pre-normalize once instead of O(k²) calls inside the inner loop, then
# compare list slices in C rather than glyph-by-glyph in Python.
max_k = min(len(c_words), len(cand_words))
c_norm = [_dedupe_norm(w) for w in c_words[-max_k:]]
cand_norm = [_dedupe_norm(w) for w in cand_words]
for k in range(max_k, 0, -1):
if c_norm[-k:] == cand_norm[:k]:
return " ".join(cand_words[k:])
return candidate_out


def _find_kth_cjk_pos(text: str, k: int) -> Optional[int]:
"""Return index after the k-th CJK glyph in text, or None if text
contains fewer than k CJK glyphs."""
seen = 0
for i, c in enumerate(text):
if c.isspace() or not _is_cjk(c):
continue
seen += 1
if seen == k:
return i + 1
return None


def _dedupe_cjk_char_level(committed_text: str, candidate_out: str) -> str:
"""Drop leading CJK glyphs of ``candidate_out`` matching the CJK-tail of
``committed_text``. Non-CJK glyphs are skipped during match, preserved
in trimmed output."""
cand_chars = [c for c in candidate_out if not c.isspace() and _is_cjk(c)]
if not cand_chars:
return candidate_out
# Longest possible overlap is bounded by candidate CJK length; collect
# only that tail of committed CJK glyphs instead of scanning the whole
# history. We iterate committed_text in reverse and stop once we have
# len(cand_chars) CJK glyphs.
max_cand = len(cand_chars)
c_tail_rev = []
for c in reversed(committed_text):
if c.isspace() or not _is_cjk(c):
continue
c_tail_rev.append(c)
if len(c_tail_rev) >= max_cand:
break
if not c_tail_rev:
return candidate_out
c_chars = list(reversed(c_tail_rev))
max_k = min(len(c_chars), len(cand_chars))
for k in range(max_k, 0, -1):
if c_chars[-k:] != cand_chars[:k]:
continue
cut_pos = _find_kth_cjk_pos(candidate_out, k)
if cut_pos is None:
return ""
return candidate_out[cut_pos:].lstrip()
return candidate_out


def dedupe_overlap(committed_text: str, candidate_out: str) -> str:
"""Trim words/CJK glyphs at the start of ``candidate_out`` that
re-transcribe ``committed_text``'s tail. Word-level first, CJK
char-level fallback."""
if not committed_text or not candidate_out:
return candidate_out
deduped = _dedupe_word_level(committed_text, candidate_out)
if deduped != candidate_out:
return deduped
if any(_is_cjk(c) for c in committed_text) or any(
_is_cjk(c) for c in candidate_out
):
return _dedupe_cjk_char_level(committed_text, candidate_out)
return candidate_out


async def process_asr_chunk(
tokenizer_manager: TokenizerManager,
adapter: TranscriptionAdapter,
state: StreamingASRState,
audio_data: bytes,
audio_data: Union[bytes, np.ndarray],
sampling_params: Dict[str, Any],
is_last: bool,
raw_request: Optional[Request] = None,
routing_key: Optional[str] = None,
prompt: Optional[str] = None,
dedupe_against: Optional[str] = None,
) -> str:
"""Run inference on one audio chunk. Shared by the HTTP and WebSocket paths."""
prompt = adapter.prompt_template + state.get_prefix_text()
"""Run inference on one audio chunk. Shared by the HTTP and WS paths.

``audio_data`` accepts WAV bytes or pre-decoded float samples.
``prompt`` overrides the default ``adapter.prompt_template + state.get_prefix_text()``.
``dedupe_against`` triggers ``dedupe_overlap`` on raw model output before ``state`` ingests it.
"""
if prompt is None:
prompt = adapter.prompt_template + state.get_prefix_text()

chunk_request = GenerateReqInput(
text=prompt,
Expand Down Expand Up @@ -202,6 +302,8 @@ async def process_asr_chunk(
return ""

text = normalize_whitespace(adapter.postprocess_text(ret.get("text", "")))
if dedupe_against is not None:
text = dedupe_overlap(dedupe_against, text)

if is_last:
state.full_transcript = text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def chunked_streaming_config(self) -> dict:
"""
return {}

@property
def realtime_slicing_config(self) -> dict:
"""Tuning knobs for the WS realtime slicing path. Only consulted
when ``supports_chunked_streaming`` is True. Override per adapter
when the model's token rate or per-chunk stability differs.

``left_overlap_ms``: audio kept across the sliced boundary so
dedupe has context; cover the K-token rollback window.
``min_audio_sec``: don't slice below this many seconds of
cumulative audio (sliced output diverges from cumulative
on short inputs and dedupe over-matches).
"""
return {"left_overlap_ms": 2000, "min_audio_sec": 16.0}

def postprocess_text(self, text: str) -> str:
"""Strip model-specific markers from raw decoded text.

Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,11 +772,20 @@ def set_random_seed(seed: int) -> None:


def load_audio(
audio_file: str, sr: Optional[int] = None, mono: bool = True
audio_file: Union[str, bytes, np.ndarray],
sr: Optional[int] = None,
mono: bool = True,
) -> np.ndarray:
if sr is None:
sr = 16000

# Caller must pre-resample to `sr`. Multi-channel layout assumed
# (n_samples, n_channels) per soundfile.read.
if isinstance(audio_file, np.ndarray):
if mono and audio_file.ndim > 1:
return np.mean(audio_file, axis=1)
return audio_file

# Normalize input: resolve URL / base64 / file:// to bytes or path
if isinstance(audio_file, bytes):
source = audio_file
Expand Down
Loading
Loading