Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 82 additions & 21 deletions python/sglang/srt/entrypoints/openai/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
from __future__ import annotations

import asyncio
import io
import json
import logging
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pybase64
import soundfile as sf
from fastapi import WebSocket, WebSocketDisconnect
from openai.types.realtime import (
ConversationItemCreatedEvent,
Expand Down Expand Up @@ -83,6 +81,13 @@
_SAMPLE_WIDTH = 2


def _slice_pcm_from(buffer: Union[bytes, bytearray], start: int) -> bytes:
    """Return an immutable ``buffer[start:]`` snapshot of a PCM byte buffer.

    Args:
        buffer: PCM bytes to slice; it is read but never modified.
        start: Byte offset of the first byte to keep. Must lie in
            ``[0, len(buffer)]`` and be a multiple of ``_SAMPLE_WIDTH``.

    Returns:
        A new ``bytes`` object equal to ``buffer[start:]``.

    Raises:
        ValueError: If ``start`` is out of bounds or not sample-aligned.
    """
    if not (0 <= start <= len(buffer)):
        raise ValueError(f"_slice_pcm_from: start={start} not in [0, {len(buffer)}]")
    # A start that lands inside a sample would pair bytes from two different
    # samples in everything that follows (or leave a trailing partial sample),
    # so reject it here with a clear message instead of decoding garbage.
    if start % _SAMPLE_WIDTH != 0:
        raise ValueError(
            f"_slice_pcm_from: start={start} must be a multiple of {_SAMPLE_WIDTH}"
        )
    # memoryview avoids an intermediate copy; bytes() takes the one snapshot.
    return bytes(memoryview(buffer)[start:])
Comment on lines +84 to +88
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

As a defensive programming practice, consider adding a check in _slice_pcm_from to ensure that the start offset is a multiple of _SAMPLE_WIDTH. This guarantees that the sliced buffer is properly aligned to 16-bit PCM boundaries, preventing silent audio corruption or misalignment.

Suggested change
def _slice_pcm_from(buffer: Union[bytes, bytearray], start: int) -> bytes:
"""Return an immutable ``buffer[start:]`` snapshot with bounds checking."""
if not (0 <= start <= len(buffer)):
raise ValueError(f"_slice_pcm_from: start={start} not in [0, {len(buffer)}]")
return bytes(memoryview(buffer)[start:])
def _slice_pcm_from(buffer: Union[bytes, bytearray], start: int) -> bytes:
"""Return an immutable ``buffer[start:]`` snapshot with bounds checking."""
if not (0 <= start <= len(buffer)):
raise ValueError(f"_slice_pcm_from: start={start} not in [0, {len(buffer)}]")
if start % _SAMPLE_WIDTH != 0:
raise ValueError(f"_slice_pcm_from: start={start} must be a multiple of {_SAMPLE_WIDTH}")
return bytes(memoryview(buffer)[start:])



def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> bytes:
if src_rate == target_rate or not pcm:
return pcm
Expand All @@ -99,11 +104,10 @@ def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> byt
return (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16).tobytes()


def _pcm_to_wav(pcm: bytes, sample_rate: int) -> bytes:
samples = np.frombuffer(pcm, dtype=np.int16)
buf = io.BytesIO()
sf.write(buf, samples, sample_rate, format="WAV")
return buf.getvalue()
def _pcm_to_float_samples(pcm: bytes) -> np.ndarray:
# /32768.0 matches soundfile.read's default int16 normalization so the
# samples are bit-equal to the prior PCM→WAV→sf.read path.
return np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0


_CLIENT_EVENT_TYPES: Dict[str, type] = {
Expand Down Expand Up @@ -139,17 +143,23 @@ class _SessionConfig:

@dataclass
class _AudioState:
"""Per-item audio state: PCM buffer accumulated from
input_audio_buffer.append, the chunked ASR rollback state, and the
static buffer-size limits set at __init__. pcm_buffer / state /
last_inference_offset reset on commit-roll and clear; the size limits
stay constant for the session's lifetime."""
"""Per-item audio buffer and slicing state.

After the slicing gate is reached, inference switches from the cumulative
buffer to a tail slice. The first gated call may still start at offset 0;
later calls use ``last_sliced_buffer_end_bytes - left_overlap_bytes``."""

max_buffer_bytes: int
chunk_size_bytes: int
left_overlap_bytes: int
slicing_min_chunk_index: int
state: StreamingASRState
# False when the left overlap covers the whole unfixed-chunk window (the
# K-unfixed dedupe target would be unreachable); set at construction.
slicing_enabled: bool = True
pcm_buffer: bytearray = field(default_factory=bytearray)
last_inference_offset: int = 0
last_sliced_buffer_end_bytes: int = 0


@dataclass
Expand Down Expand Up @@ -190,17 +200,40 @@ def __init__(

self.config = _SessionConfig()

slicing_cfg = adapter.realtime_slicing_config
slicing_opt_in = bool(slicing_cfg.get("enabled", False))
left_overlap_ms = int(slicing_cfg.get("left_overlap_ms", 0))
min_audio_sec = float(slicing_cfg.get("min_audio_sec", 0.0))
left_overlap_bytes = int(left_overlap_ms / 1000 * self.bytes_per_second)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential audio misalignment and corruption, left_overlap_bytes should be explicitly aligned to a multiple of _SAMPLE_WIDTH (2 bytes). If left_overlap_bytes is not aligned, slicing the PCM buffer could cut a 16-bit sample in half, leading to static noise or runtime errors during conversion.

Suggested change
left_overlap_bytes = int(left_overlap_ms / 1000 * self.bytes_per_second)
left_overlap_bytes = int(left_overlap_ms / 1000 * self.bytes_per_second) // _SAMPLE_WIDTH * _SAMPLE_WIDTH


state = StreamingASRState(**adapter.chunked_streaming_config)
chunk_size_bytes = int(state.chunk_size_sec * self.bytes_per_second)
if chunk_size_bytes <= 0:
raise RuntimeError(
f"adapter.chunked_streaming_config produced non-positive "
f"chunk_size_sec; got {state.chunk_size_sec!r}"
)
slicing_min_chunk_index = (
math.ceil(min_audio_sec / state.chunk_size_sec) if slicing_opt_in else 0
)
slicing_enabled = (
slicing_opt_in
and left_overlap_bytes < state.unfixed_chunk_num * chunk_size_bytes
)
if slicing_opt_in and not slicing_enabled:
logger.warning(
"[realtime] left_overlap=%dms >= unfixed_chunks_duration=%dms; "
"audio slicing disabled, falling back to cumulative inference",
left_overlap_ms,
state.unfixed_chunk_num * int(state.chunk_size_sec * 1000),
)
self.audio = _AudioState(
max_buffer_bytes=self.max_buffer_seconds * self.bytes_per_second,
chunk_size_bytes=chunk_size_bytes,
state=state,
left_overlap_bytes=left_overlap_bytes,
slicing_min_chunk_index=slicing_min_chunk_index,
slicing_enabled=slicing_enabled,
)

self.item = _ItemState(current_item_id=f"item_{random_uuid()}")
Expand Down Expand Up @@ -543,8 +576,8 @@ async def _on_input_audio_buffer_commit(
tail = self.audio.state.finalize()
await self._emit_transcription_delta(tail)

# Build from emitted_deltas, not state.full_transcript: prefix injection
# means the last chunk's full_transcript is only the continuation tail.
# Rebuild from emitted_deltas: both paths leave full_transcript only a
# partial tail, while the deltas together are the whole transcript.
transcript = normalize_whitespace("".join(self.item.emitted_deltas))

await self._send(
Expand Down Expand Up @@ -579,20 +612,42 @@ async def _on_input_audio_buffer_clear(
)

async def _run_inference(self, is_last: bool) -> bool:
"""Run ASR on the current cumulative buffer. Returns False on failure:
commit-time emits transcription.failed and rolls the item; append-time
emits a generic error envelope and closes the WebSocket."""
wav_data = await asyncio.to_thread(
_pcm_to_wav, bytes(self.audio.pcm_buffer), self.model_sample_rate
"""Run ASR on the current audio window: the whole PCM buffer
(cumulative) or a tail slice with left overlap + output dedupe
(slicing). Returns False on failure -- commit-time emits
transcription.failed and rolls the item; append-time closes the WS."""
# Slicing uses a bare prompt: the retained overlap + dedupe replace
# injecting emitted_text as a continuation prefix.
committed_text = self.audio.state.get_prefix_text()
use_slicing = (
self.audio.slicing_enabled
and bool(committed_text)
and self.audio.state.chunk_index >= self.audio.slicing_min_chunk_index
)
if use_slicing:
prompt: Optional[str] = self.adapter.prompt_template
dedupe_against: Optional[str] = committed_text
slice_start = max(
0,
self.audio.last_sliced_buffer_end_bytes - self.audio.left_overlap_bytes,
)
else:
prompt = None
dedupe_against = None
slice_start = 0

try:
pcm_slice = _slice_pcm_from(self.audio.pcm_buffer, slice_start)
audio_samples = await asyncio.to_thread(_pcm_to_float_samples, pcm_slice)
delta = await process_asr_chunk(
tokenizer_manager=self.tokenizer_manager,
adapter=self.adapter,
state=self.audio.state,
audio_data=wav_data,
audio_data=audio_samples,
sampling_params=self.config.sampling_params,
is_last=is_last,
prompt=prompt,
dedupe_against=dedupe_against,
)
except Exception:
logger.exception(
Expand Down Expand Up @@ -632,6 +687,11 @@ async def _run_inference(self, is_last: bool) -> bool:
)
return False

if use_slicing:
# Held-back tokens are re-covered only if their audio span fits the
# left overlap; slower speech can drop the earliest (see known limits).
self.audio.last_sliced_buffer_end_bytes = len(self.audio.pcm_buffer)

self.audio.last_inference_offset = len(self.audio.pcm_buffer)
await self._emit_transcription_delta(delta)
return True
Expand Down Expand Up @@ -669,6 +729,7 @@ def _reset_inference_state(self) -> None:
self.audio.pcm_buffer.clear() # in-place; reuses the buffer's allocation
self.item.emitted_deltas.clear()
self.audio.last_inference_offset = 0
self.audio.last_sliced_buffer_end_bytes = 0

def _build_session_info(self) -> TranscriptionSessionConfig:
# id / object aren't SDK fields; round-trip via extra='allow' so
Expand Down
109 changes: 88 additions & 21 deletions python/sglang/srt/entrypoints/openai/streaming_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import io
import logging
import re
import unicodedata
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
import soundfile as sf
from fastapi import Request

Expand Down Expand Up @@ -40,8 +42,8 @@ class StreamingASRState:
unfixed_chunk_num: int
unfixed_token_num: int
confirmed_text: str = ""
# Monotonic accumulator; used as prompt prefix so the model sees a
# natural continuation point, not the rolled-back ``confirmed_text``.
# Monotonic accumulator. Used as the prompt prefix on cumulative paths and
# as the dedupe prefix on the slicing path.
emitted_text: str = ""
full_transcript: str = ""
chunk_index: int = 0
Expand All @@ -53,9 +55,13 @@ def get_prefix_text(self) -> str:

def _record_emit(self, delta: str) -> str:
if delta:
self.emitted_text = (
f"{self.emitted_text} {delta}".strip() if self.emitted_text else delta
)
if self.emitted_text:
# needs_space avoids a space between adjacent CJK characters;
# this accumulator feeds the prompt prefix and the dedupe target.
sep = " " if needs_space(self.emitted_text, delta) else ""
self.emitted_text = f"{self.emitted_text}{sep}{delta}".strip()
else:
self.emitted_text = delta
return delta

def update(self, new_transcript: str) -> str:
Expand All @@ -67,10 +73,9 @@ def update(self, new_transcript: str) -> str:
self.confirmed_text = ""
self.full_transcript = new_transcript
self.chunk_index += 1
if self.confirmed_text.startswith(old_confirmed):
return self._record_emit(self.confirmed_text[len(old_confirmed) :].strip())
# Model revised earlier text, use word level common prefix to avoid
# re-emitting already-sent content and cutting mid-word.
# Word-level common prefix, not char-level startswith: startswith
# sliced mid-word when a confirmed word was extended ("world" ->
# "worldly" emitted "ly").
old_words = old_confirmed.split()
new_words = self.confirmed_text.split()
common_count = 0
Expand Down Expand Up @@ -130,25 +135,24 @@ def normalize_whitespace(text: str) -> str:


def _is_cjk(c: str) -> bool:
"""Whether char is a CJK-context glyph that doesn't take inter-word
spaces — ideographs, Japanese kana, CJK punctuation, fullwidth forms.
Excludes Hangul / Devanagari / Arabic etc., which are non-ASCII but
space-separated and need the normal boundary space."""
"""CJK-context character that takes no inter-word space."""
cp = ord(c)
if 0xFFA0 <= cp <= 0xFFDC: # halfwidth Hangul jamo -- Korean is space-delimited
return False
return (
0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation (,。、《》「」…)
0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation
or 0x3040 <= cp <= 0x309F # Hiragana
or 0x30A0 <= cp <= 0x30FF # Katakana
or 0x30A0 <= cp <= 0x30FF # Katakana (incl. ー / ・)
or 0x3400 <= cp <= 0x4DBF # CJK Unified Ideographs Ext A
or 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
or 0xFF00 <= cp <= 0xFFEF # Halfwidth & Fullwidth Forms (fullwidth ASCII)
or 0xFF00 <= cp <= 0xFFEF # Halfwidth & Fullwidth Forms
)


def needs_space(prev: str, cur: str) -> bool:
"""Return whether a boundary space is needed between emitted deltas.

Avoid spaces around punctuation and between adjacent CJK-context glyphs.
Avoid spaces around punctuation and between adjacent CJK-context characters.
Shared by the realtime WS and HTTP SSE chunked streaming paths.
"""
if not prev or not cur:
Expand All @@ -162,18 +166,79 @@ def needs_space(prev: str, cur: str) -> bool:
return True


def _dedupe_norm(word: str) -> str:
"""Normalize a word for overlap matching: NFKC, lowercase, strip edge
punctuation (Unicode category P)."""
word = unicodedata.normalize("NFKC", word)
lo, hi = 0, len(word)
while lo < hi and unicodedata.category(word[lo])[0] == "P":
lo += 1
while hi > lo and unicodedata.category(word[hi - 1])[0] == "P":
hi -= 1
return word[lo:hi].lower()


def _dedupe_by_word(committed_text: str, candidate_out: str) -> str:
"""Drop the longest prefix of ``candidate_out`` matching the suffix of
``committed_text`` word-for-word (case- and punctuation-insensitive)."""
candidate_words = candidate_out.split()
if not candidate_words:
return candidate_out
# Only the last len(candidate_words) committed words can overlap, so rsplit
# the tail instead of tokenizing the whole (growing) committed transcript.
committed_tail = committed_text.rsplit(maxsplit=len(candidate_words))[
-len(candidate_words) :
]
if not committed_tail:
return candidate_out
# Normalize the committed tail and candidate prefix once, then compare slices.
max_overlap = min(len(committed_tail), len(candidate_words))
committed_tail_norm = [_dedupe_norm(w) for w in committed_tail]
candidate_norm = [_dedupe_norm(w) for w in candidate_words[:max_overlap]]
# Longest overlap first; the first match wins.
for overlap in range(max_overlap, 0, -1):
if committed_tail_norm[-overlap:] != candidate_norm[:overlap]:
continue
# Skip all-punctuation overlaps: lone "@"/"#" both normalize to "" and
# would match spuriously.
if not any(candidate_norm[:overlap]):
continue
return " ".join(candidate_words[overlap:])
return candidate_out


def dedupe_overlap(committed_text: str, candidate_out: str) -> str:
    """Remove the leading words of ``candidate_out`` that re-transcribe the
    tail of ``committed_text``.

    Matching is word-level and ignores case and punctuation. If either
    argument is empty, ``candidate_out`` is returned untouched.

    The word-level matcher does not help for CJK, which has no inter-word
    spaces; a character-level CJK dedupe is deferred to M3, where slicing also
    engages for CJK (today it stays on the cumulative path).
    """
    if committed_text and candidate_out:
        return _dedupe_by_word(committed_text, candidate_out)
    return candidate_out


async def process_asr_chunk(
tokenizer_manager: TokenizerManager,
adapter: TranscriptionAdapter,
state: StreamingASRState,
audio_data: bytes,
audio_data: Union[bytes, np.ndarray],
sampling_params: Dict[str, Any],
is_last: bool,
raw_request: Optional[Request] = None,
routing_key: Optional[str] = None,
prompt: Optional[str] = None,
dedupe_against: Optional[str] = None,
) -> str:
"""Run inference on one audio chunk. Shared by the HTTP and WebSocket paths."""
prompt = adapter.prompt_template + state.get_prefix_text()
"""Run inference on one audio chunk. Shared by the HTTP and WS paths.

``audio_data`` accepts WAV bytes or pre-decoded float samples.
``prompt`` overrides the default ``adapter.prompt_template + state.get_prefix_text()``.
``dedupe_against`` triggers ``dedupe_overlap`` on raw model output before ``state`` ingests it.
"""
if prompt is None:
prompt = adapter.prompt_template + state.get_prefix_text()

chunk_request = GenerateReqInput(
text=prompt,
Expand Down Expand Up @@ -202,6 +267,8 @@ async def process_asr_chunk(
return ""

text = normalize_whitespace(adapter.postprocess_text(ret.get("text", "")))
if dedupe_against is not None:
text = dedupe_overlap(dedupe_against, text)

if is_last:
state.full_transcript = text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ def chunked_streaming_config(self) -> dict:
"""
return {}

@property
def realtime_slicing_config(self) -> dict:
    """Tuning knobs for the realtime slicing path; slicing is off by default.

    An adapter opts in by overriding this property with ``enabled=True`` and
    values tuned for its model.

    Keys:
        enabled: Whether the adapter opts in to slicing.
        left_overlap_ms: Audio kept across the sliced boundary as dedupe
            context.
        min_audio_sec: Floor below which slicing stays off.
    """
    defaults: dict = {
        "enabled": False,
        "left_overlap_ms": 0,
        "min_audio_sec": 0.0,
    }
    return defaults

def postprocess_text(self, text: str) -> str:
"""Strip model-specific markers from raw decoded text.

Expand Down
Loading
Loading