-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[Feature] Realtime ASR: Input Slicing for Long-Running Realtime ASR Sessions #26853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
48b1509
9c7fa69
66a1554
5d8455d
57fcd0d
9e8530e
ae724c8
c59d414
36b3322
31dbc97
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -9,16 +9,14 @@ | |||||
| from __future__ import annotations | ||||||
|
|
||||||
| import asyncio | ||||||
| import io | ||||||
| import json | ||||||
| import logging | ||||||
| import math | ||||||
| from dataclasses import dataclass, field | ||||||
| from typing import Any, Dict, List, Optional | ||||||
| from typing import Any, Dict, List, Optional, Union | ||||||
|
|
||||||
| import numpy as np | ||||||
| import pybase64 | ||||||
| import soundfile as sf | ||||||
| from fastapi import WebSocket, WebSocketDisconnect | ||||||
| from openai.types.realtime import ( | ||||||
| ConversationItemCreatedEvent, | ||||||
|
|
@@ -83,6 +81,13 @@ | |||||
| _SAMPLE_WIDTH = 2 | ||||||
|
|
||||||
|
|
||||||
| def _slice_pcm_from(buffer: Union[bytes, bytearray], start: int) -> bytes: | ||||||
| """Return an immutable ``buffer[start:]`` snapshot with bounds checking.""" | ||||||
| if not (0 <= start <= len(buffer)): | ||||||
| raise ValueError(f"_slice_pcm_from: start={start} not in [0, {len(buffer)}]") | ||||||
| return bytes(memoryview(buffer)[start:]) | ||||||
|
|
||||||
|
|
||||||
| def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> bytes: | ||||||
| if src_rate == target_rate or not pcm: | ||||||
| return pcm | ||||||
|
|
@@ -99,11 +104,10 @@ def _resample_to_target_rate(pcm: bytes, src_rate: int, target_rate: int) -> byt | |||||
| return (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16).tobytes() | ||||||
|
|
||||||
|
|
||||||
| def _pcm_to_wav(pcm: bytes, sample_rate: int) -> bytes: | ||||||
| samples = np.frombuffer(pcm, dtype=np.int16) | ||||||
| buf = io.BytesIO() | ||||||
| sf.write(buf, samples, sample_rate, format="WAV") | ||||||
| return buf.getvalue() | ||||||
| def _pcm_to_float_samples(pcm: bytes) -> np.ndarray: | ||||||
| # /32768.0 matches soundfile.read's default int16 normalization so the | ||||||
| # samples are bit-equal to the prior PCM→WAV→sf.read path. | ||||||
| return np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0 | ||||||
|
|
||||||
|
|
||||||
| _CLIENT_EVENT_TYPES: Dict[str, type] = { | ||||||
|
|
@@ -139,17 +143,23 @@ class _SessionConfig: | |||||
|
|
||||||
| @dataclass | ||||||
| class _AudioState: | ||||||
| """Per-item audio state: PCM buffer accumulated from | ||||||
| input_audio_buffer.append, the chunked ASR rollback state, and the | ||||||
| static buffer-size limits set at __init__. pcm_buffer / state / | ||||||
| last_inference_offset reset on commit-roll and clear; the size limits | ||||||
| stay constant for the session's lifetime.""" | ||||||
| """Per-item audio buffer and slicing state. | ||||||
|
|
||||||
| After the slicing gate is reached, inference switches from the cumulative | ||||||
| buffer to a tail slice. The first gated call may still start at offset 0; | ||||||
| later calls use ``last_sliced_buffer_end_bytes - left_overlap_bytes``.""" | ||||||
|
|
||||||
| max_buffer_bytes: int | ||||||
| chunk_size_bytes: int | ||||||
| left_overlap_bytes: int | ||||||
| slicing_min_chunk_index: int | ||||||
| state: StreamingASRState | ||||||
| # False when the left overlap covers the whole unfixed-chunk window (the | ||||||
| # K-unfixed dedupe target would be unreachable); set at construction. | ||||||
| slicing_enabled: bool = True | ||||||
| pcm_buffer: bytearray = field(default_factory=bytearray) | ||||||
| last_inference_offset: int = 0 | ||||||
| last_sliced_buffer_end_bytes: int = 0 | ||||||
|
|
||||||
|
|
||||||
| @dataclass | ||||||
|
|
@@ -190,17 +200,40 @@ def __init__( | |||||
|
|
||||||
| self.config = _SessionConfig() | ||||||
|
|
||||||
| slicing_cfg = adapter.realtime_slicing_config | ||||||
| slicing_opt_in = bool(slicing_cfg.get("enabled", False)) | ||||||
| left_overlap_ms = int(slicing_cfg.get("left_overlap_ms", 0)) | ||||||
| min_audio_sec = float(slicing_cfg.get("min_audio_sec", 0.0)) | ||||||
| left_overlap_bytes = int(left_overlap_ms / 1000 * self.bytes_per_second) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To prevent potential audio misalignment and corruption,
Suggested change
|
||||||
|
|
||||||
| state = StreamingASRState(**adapter.chunked_streaming_config) | ||||||
| chunk_size_bytes = int(state.chunk_size_sec * self.bytes_per_second) | ||||||
| if chunk_size_bytes <= 0: | ||||||
| raise RuntimeError( | ||||||
| f"adapter.chunked_streaming_config produced non-positive " | ||||||
| f"chunk_size_sec; got {state.chunk_size_sec!r}" | ||||||
| ) | ||||||
| slicing_min_chunk_index = ( | ||||||
| math.ceil(min_audio_sec / state.chunk_size_sec) if slicing_opt_in else 0 | ||||||
| ) | ||||||
| slicing_enabled = ( | ||||||
| slicing_opt_in | ||||||
| and left_overlap_bytes < state.unfixed_chunk_num * chunk_size_bytes | ||||||
| ) | ||||||
| if slicing_opt_in and not slicing_enabled: | ||||||
| logger.warning( | ||||||
| "[realtime] left_overlap=%dms >= unfixed_chunks_duration=%dms; " | ||||||
| "audio slicing disabled, falling back to cumulative inference", | ||||||
| left_overlap_ms, | ||||||
| state.unfixed_chunk_num * int(state.chunk_size_sec * 1000), | ||||||
| ) | ||||||
| self.audio = _AudioState( | ||||||
| max_buffer_bytes=self.max_buffer_seconds * self.bytes_per_second, | ||||||
| chunk_size_bytes=chunk_size_bytes, | ||||||
| state=state, | ||||||
| left_overlap_bytes=left_overlap_bytes, | ||||||
| slicing_min_chunk_index=slicing_min_chunk_index, | ||||||
| slicing_enabled=slicing_enabled, | ||||||
| ) | ||||||
|
|
||||||
| self.item = _ItemState(current_item_id=f"item_{random_uuid()}") | ||||||
|
|
@@ -543,8 +576,8 @@ async def _on_input_audio_buffer_commit( | |||||
| tail = self.audio.state.finalize() | ||||||
| await self._emit_transcription_delta(tail) | ||||||
|
|
||||||
| # Build from emitted_deltas, not state.full_transcript: prefix injection | ||||||
| # means the last chunk's full_transcript is only the continuation tail. | ||||||
| # Rebuild from emitted_deltas: both paths leave full_transcript only a | ||||||
| # partial tail, while the deltas together are the whole transcript. | ||||||
| transcript = normalize_whitespace("".join(self.item.emitted_deltas)) | ||||||
|
|
||||||
| await self._send( | ||||||
|
|
@@ -579,20 +612,42 @@ async def _on_input_audio_buffer_clear( | |||||
| ) | ||||||
|
|
||||||
| async def _run_inference(self, is_last: bool) -> bool: | ||||||
| """Run ASR on the current cumulative buffer. Returns False on failure: | ||||||
| commit-time emits transcription.failed and rolls the item; append-time | ||||||
| emits a generic error envelope and closes the WebSocket.""" | ||||||
| wav_data = await asyncio.to_thread( | ||||||
| _pcm_to_wav, bytes(self.audio.pcm_buffer), self.model_sample_rate | ||||||
| """Run ASR on the current audio window: the whole PCM buffer | ||||||
| (cumulative) or a tail slice with left overlap + output dedupe | ||||||
| (slicing). Returns False on failure -- commit-time emits | ||||||
| transcription.failed and rolls the item; append-time closes the WS.""" | ||||||
| # Slicing uses a bare prompt: the retained overlap + dedupe replace | ||||||
| # injecting emitted_text as a continuation prefix. | ||||||
| committed_text = self.audio.state.get_prefix_text() | ||||||
| use_slicing = ( | ||||||
| self.audio.slicing_enabled | ||||||
| and bool(committed_text) | ||||||
| and self.audio.state.chunk_index >= self.audio.slicing_min_chunk_index | ||||||
| ) | ||||||
| if use_slicing: | ||||||
| prompt: Optional[str] = self.adapter.prompt_template | ||||||
| dedupe_against: Optional[str] = committed_text | ||||||
| slice_start = max( | ||||||
| 0, | ||||||
| self.audio.last_sliced_buffer_end_bytes - self.audio.left_overlap_bytes, | ||||||
| ) | ||||||
| else: | ||||||
| prompt = None | ||||||
| dedupe_against = None | ||||||
| slice_start = 0 | ||||||
|
|
||||||
| try: | ||||||
| pcm_slice = _slice_pcm_from(self.audio.pcm_buffer, slice_start) | ||||||
| audio_samples = await asyncio.to_thread(_pcm_to_float_samples, pcm_slice) | ||||||
| delta = await process_asr_chunk( | ||||||
| tokenizer_manager=self.tokenizer_manager, | ||||||
| adapter=self.adapter, | ||||||
| state=self.audio.state, | ||||||
| audio_data=wav_data, | ||||||
| audio_data=audio_samples, | ||||||
| sampling_params=self.config.sampling_params, | ||||||
| is_last=is_last, | ||||||
| prompt=prompt, | ||||||
| dedupe_against=dedupe_against, | ||||||
| ) | ||||||
| except Exception: | ||||||
| logger.exception( | ||||||
|
|
@@ -632,6 +687,11 @@ async def _run_inference(self, is_last: bool) -> bool: | |||||
| ) | ||||||
| return False | ||||||
|
|
||||||
| if use_slicing: | ||||||
| # Held-back tokens are re-covered only if their audio span fits the | ||||||
| # left overlap; slower speech can drop the earliest (see known limits). | ||||||
| self.audio.last_sliced_buffer_end_bytes = len(self.audio.pcm_buffer) | ||||||
|
|
||||||
| self.audio.last_inference_offset = len(self.audio.pcm_buffer) | ||||||
| await self._emit_transcription_delta(delta) | ||||||
| return True | ||||||
|
|
@@ -669,6 +729,7 @@ def _reset_inference_state(self) -> None: | |||||
| self.audio.pcm_buffer.clear() # in-place; reuses the buffer's allocation | ||||||
| self.item.emitted_deltas.clear() | ||||||
| self.audio.last_inference_offset = 0 | ||||||
| self.audio.last_sliced_buffer_end_bytes = 0 | ||||||
|
|
||||||
| def _build_session_info(self) -> TranscriptionSessionConfig: | ||||||
| # id / object aren't SDK fields; round-trip via extra='allow' so | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a defensive programming practice, consider adding a check in
_slice_pcm_fromto ensure that thestartoffset is a multiple of_SAMPLE_WIDTH. This guarantees that the sliced buffer is properly aligned to 16-bit PCM boundaries, preventing silent audio corruption or misalignment.