From d9f79b34cf41024913d5cc1f9701dbe30f0ef1fd Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Thu, 23 Oct 2025 23:27:12 +0200 Subject: [PATCH 01/15] stt simplify wip --- .../vision_agents/core/agents/agents.py | 34 +- agents-core/vision_agents/core/edge/types.py | 400 ++++++++++++++++-- .../core/observability/__init__.py | 2 + .../core/observability/metrics.py | 3 + .../vision_agents/core/tts/manual_test.py | 82 ++++ agents-core/vision_agents/core/tts/testing.py | 81 ++++ agents-core/vision_agents/core/tts/tts.py | 377 +++++++++-------- docs/ai/instructions/ai-tts.md | 107 +++-- .../simple_agent_example.py | 6 +- plugins/aws/tests/test_aws.py | 7 +- plugins/cartesia/tests/test_tts.py | 208 ++------- .../vision_agents/plugins/cartesia/tts.py | 40 +- plugins/elevenlabs/tests/test_tts.py | 308 ++------------ .../vision_agents/plugins/elevenlabs/tts.py | 39 +- plugins/fish/tests/test_tts.py | 106 +---- .../fish/vision_agents/plugins/fish/tts.py | 78 ++-- plugins/kokoro/tests/test_tts.py | 172 +------- .../vision_agents/plugins/kokoro/tts.py | 33 +- tests/test_tts_base.py | 215 ++++++++++ 19 files changed, 1213 insertions(+), 1085 deletions(-) create mode 100644 agents-core/vision_agents/core/tts/manual_test.py create mode 100644 agents-core/vision_agents/core/tts/testing.py create mode 100644 tests/test_tts_base.py diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py index 9a65f6a7..2b4cb94e 100644 --- a/agents-core/vision_agents/core/agents/agents.py +++ b/agents-core/vision_agents/core/agents/agents.py @@ -32,6 +32,7 @@ from ..stt.events import STTTranscriptEvent from ..stt.stt import STT from ..tts.tts import TTS +from ..tts.events import TTSAudioEvent from ..turn_detection import TurnDetector, TurnStartedEvent, TurnEndedEvent from ..vad import VAD from ..vad.events import VADAudioEvent @@ -302,6 +303,18 @@ async def on_realtime_agent_speech_transcription( original=event, ) + # Listen for TTS audio events and write audio to the output track + @self.events.subscribe + async def _on_tts_audio(event: TTSAudioEvent): + try: + if self._audio_track and event.audio_data: + from typing import Any, cast + + track_any = cast(Any, self._audio_track) + await track_any.write(event.audio_data) + except Exception as e: + self.logger.error(f"Error writing TTS audio to track: {e}") + @self.events.subscribe async def on_stt_transcript_event_create_response(event: STTTranscriptEvent): if self.realtime_mode or not self.llm: @@ -1016,19 +1029,22 @@ def _prepare_rtc(self): self._audio_track = self.llm.output_track self.logger.info("🎵 Using Realtime provider output track for audio") else: - # TODO: what if we want to transform audio... 
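+            # 48 kHz stereo matches WebRTC's native Opus output; the TTS now
+            # adapts to whatever the agent requests via set_output_format()
+            # below, so plugins no longer need to match the track's framerate.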
- # Get the required framerate and stereo setting from TTS plugin, default to 48000 for WebRTC - if self.tts: - framerate = self.tts.get_required_framerate() - stereo = self.tts.get_required_stereo() - else: - framerate = 48000 - stereo = True # Default to stereo for WebRTC + # Default to WebRTC-friendly format unless configured differently + framerate = 48000 + stereo = True self._audio_track = self.edge.create_audio_track( framerate=framerate, stereo=stereo ) + # Inform TTS of desired output format so it can resample accordingly if self.tts: - self.tts.set_output_track(self._audio_track) + channels = 2 if stereo else 1 + try: + self.tts.set_output_format( + sample_rate=framerate, + channels=channels, + ) + except Exception as e: + self.logger.warning(f"Failed to set TTS output format: {e}") # Set up video track if video publishers are available if self.publish_video: diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py index 5f68b1a3..fb5c1ba1 100644 --- a/agents-core/vision_agents/core/edge/types.py +++ b/agents-core/vision_agents/core/edge/types.py @@ -1,6 +1,5 @@ -#from __future__ import annotations from dataclasses import dataclass -from typing import Any, Optional, NamedTuple +from typing import Any, Optional, NamedTuple, Union, Iterator, AsyncIterator import logging import numpy as np @@ -30,6 +29,7 @@ class Connection(AsyncIOEventEmitter): and a way to receive a callback when the call is ended In the future we might want to forward more events """ + async def close(self): pass @@ -53,6 +53,7 @@ class PcmData(NamedTuple): pts: Optional[int] = None # Presentation timestamp dts: Optional[int] = None # Decode timestamp time_base: Optional[float] = None # Time base for converting timestamps to seconds + channels: int = 1 # Number of channels (1=mono, 2=stereo) @property def duration(self) -> float: @@ -67,8 +68,11 @@ def duration(self) -> float: # For f32 format, each element in the array is one sample (float32) if isinstance(self.samples, np.ndarray): - # Direct count of samples in the numpy array - num_samples = len(self.samples) + # If array has shape (channels, samples), duration uses the samples dimension + if self.samples.ndim == 2: + num_samples = self.samples.shape[-1] + else: + num_samples = len(self.samples) elif isinstance(self.samples, bytes): # If samples is bytes, calculate based on format if self.format == "s16": @@ -93,6 +97,11 @@ def duration(self) -> float: # Calculate duration based on sample rate return num_samples / self.sample_rate + @property + def duration_ms(self) -> float: + """Duration in milliseconds computed from samples and sample rate.""" + return self.duration * 1000.0 + @property def pts_seconds(self) -> Optional[float]: if self.pts is not None and self.time_base is not None: @@ -107,77 +116,390 @@ def dts_seconds(self) -> Optional[float]: @classmethod def from_bytes( - cls, - audio_bytes: bytes, - sample_rate: int = 16000, - format: str = "s16" + cls, + audio_bytes: bytes, + sample_rate: int = 16000, + format: str = "s16", + channels: int = 1, ) -> "PcmData": - """ - Create PcmData from raw audio bytes. - + """Create PcmData from raw PCM bytes (interleaved for multi-channel). + Args: - audio_bytes: Raw audio data as bytes - sample_rate: Sample rate in Hz - format: Audio format (e.g., "s16", "f32") - + audio_bytes: Raw PCM data as bytes. + sample_rate: Sample rate in Hz. + format: Audio sample format, e.g. "s16" or "f32". + channels: Number of channels (1=mono, 2=stereo). 
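+
+        Example (illustrative)::
+
+            # two interleaved s16 stereo frames -> samples.shape == (2, 2)
+            pcm = PcmData.from_bytes(
+                b"\x01\x00\x02\x00\x03\x00\x04\x00",
+                sample_rate=48000, format="s16", channels=2,
+            )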
+ Returns: - PcmData object + PcmData object with numpy samples (mono: 1D, multi-channel: 2D [channels, samples]). """ - audio_array = np.frombuffer(audio_bytes, dtype=np.int16) - return cls(samples=audio_array, sample_rate=sample_rate, format=format) + # Determine dtype and bytes per sample + dtype: Any + width: int + if format == "s16": + dtype = np.int16 + width = 2 + elif format == "f32": + dtype = np.float32 + width = 4 + else: + dtype = np.int16 + width = 2 - def resample(self, target_sample_rate: int) -> "PcmData": + # Ensure buffer aligns to whole samples + if len(audio_bytes) % width != 0: + trimmed = len(audio_bytes) - (len(audio_bytes) % width) + if trimmed <= 0: + return cls( + samples=np.array([], dtype=dtype), + sample_rate=sample_rate, + format=format, + channels=channels, + ) + logger.debug( + "Trimming non-aligned PCM buffer: %d -> %d bytes", + len(audio_bytes), + trimmed, + ) + audio_bytes = audio_bytes[:trimmed] + + arr = np.frombuffer(audio_bytes, dtype=dtype) + if channels > 1 and arr.size > 0: + # Convert interleaved [L,R,L,R,...] to shape (channels, samples) + total_frames = (arr.size // channels) * channels + if total_frames != arr.size: + logger.debug( + "Trimming interleaved frames to channel multiple: %d -> %d elements", + arr.size, + total_frames, + ) + arr = arr[:total_frames] + try: + frames = arr.reshape(-1, channels) + arr = frames.T + except Exception: + logger.warning( + f"Unable to reshape audio buffer to {channels} channels; falling back to 1D" + ) + return cls( + samples=arr, sample_rate=sample_rate, format=format, channels=channels + ) + + @classmethod + def from_data( + cls, + data: Union[bytes, bytearray, memoryview, NDArray], + sample_rate: int = 16000, + format: str = "s16", + channels: int = 1, + ) -> "PcmData": + """Create PcmData from bytes or numpy arrays. + + - bytes-like: interpreted as interleaved PCM per channel. + - numpy arrays: accepts 1D [samples], 2D [channels, samples] or [samples, channels]. + """ + if isinstance(data, (bytes, bytearray, memoryview)): + return cls.from_bytes( + bytes(data), sample_rate=sample_rate, format=format, channels=channels + ) + + if isinstance(data, np.ndarray): + arr = data + # Ensure dtype aligns with format + if format == "s16" and arr.dtype != np.int16: + arr = arr.astype(np.int16) + elif format == "f32" and arr.dtype != np.float32: + arr = arr.astype(np.float32) + + # Normalize shape to (channels, samples) for multi-channel + if arr.ndim == 2: + if arr.shape[0] == channels: + samples_arr = arr + elif arr.shape[1] == channels: + samples_arr = arr.T + else: + # Assume first dimension is channels if ambiguous + samples_arr = arr + elif arr.ndim == 1: + if channels > 1: + try: + frames = arr.reshape(-1, channels) + samples_arr = frames.T + except Exception: + logger.warning( + f"Could not reshape 1D array to {channels} channels; keeping mono" + ) + channels = 1 + samples_arr = arr + else: + samples_arr = arr + else: + # Fallback + samples_arr = arr.reshape(-1) + channels = 1 + + return cls( + samples=samples_arr, + sample_rate=sample_rate, + format=format, + channels=channels, + ) + + # Unsupported type + raise TypeError(f"Unsupported data type for PcmData: {type(data)}") + + def resample( + self, target_sample_rate: int, target_channels: Optional[int] = None + ) -> "PcmData": """ - Resample PcmData to a different sample rate using AV library. - + Resample PcmData to a different sample rate and/or channels using AV library. 
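+
+        Example (illustrative)::
+
+            # 16 kHz mono speech -> 48 kHz stereo for a WebRTC output track
+            pcm48 = pcm.resample(48000, target_channels=2)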
+ Args: target_sample_rate: Target sample rate in Hz - + target_channels: Target number of channels (defaults to current) + Returns: New PcmData object with resampled audio """ - if self.sample_rate == target_sample_rate: + if target_channels is None: + target_channels = self.channels + if self.sample_rate == target_sample_rate and target_channels == self.channels: return self - - # Ensure samples are 2D for AV library (channels, samples) + + # Prepare ndarray shape for AV. + # Our convention: (channels, samples) for multi-channel, (samples,) for mono. samples = self.samples if samples.ndim == 1: - # Reshape 1D array to 2D (1 channel, samples) + # Mono: reshape to (1, samples) for AV samples = samples.reshape(1, -1) - + elif samples.ndim == 2: + # Already (channels, samples) + pass + # Create AV audio frame from the samples - frame = av.AudioFrame.from_ndarray(samples, format='s16', layout='mono') + in_layout = "mono" if self.channels == 1 else "stereo" + # For multi-channel, use planar format to avoid packed shape errors + in_format = "s16" if self.channels == 1 else "s16p" + samples = np.ascontiguousarray(samples) + frame = av.AudioFrame.from_ndarray(samples, format=in_format, layout=in_layout) frame.sample_rate = self.sample_rate - + # Create resampler + out_layout = "mono" if target_channels == 1 else "stereo" resampler = av.AudioResampler( - format='s16', - layout='mono', - rate=target_sample_rate + format="s16", layout=out_layout, rate=target_sample_rate ) - + # Resample the frame resampled_frames = resampler.resample(frame) if resampled_frames: resampled_frame = resampled_frames[0] resampled_samples = resampled_frame.to_ndarray() - + # AV returns (channels, samples), so for mono we want the first (and only) channel if len(resampled_samples.shape) > 1: - # Take the first channel (mono) - resampled_samples = resampled_samples[0] - + if target_channels == 1: + resampled_samples = resampled_samples[0] + # Convert to int16 resampled_samples = resampled_samples.astype(np.int16) - + return PcmData( samples=resampled_samples, sample_rate=target_sample_rate, format=self.format, pts=self.pts, dts=self.dts, - time_base=self.time_base + time_base=self.time_base, + channels=target_channels, ) else: # If resampling failed, return original data return self + + def to_bytes(self) -> bytes: + """Return interleaved PCM bytes (s16 or f32 depending on format).""" + arr = self.samples + if isinstance(arr, np.ndarray): + if arr.ndim == 2: + # (channels, samples) -> interleaved (samples, channels) + interleaved = arr.T.reshape(-1) + return interleaved.tobytes() + return arr.tobytes() + # Fallback + if isinstance(arr, (bytes, bytearray)): + return bytes(arr) + try: + return bytes(arr) + except Exception: + logger.warning("Cannot convert samples to bytes; returning empty") + return b"" + + def to_wav_bytes(self) -> bytes: + """Return a complete WAV file (header + frames) as bytes. + + Notes: + - If the data format is not s16, it will be converted to s16. + - Channels and sample rate are taken from the PcmData instance. 
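+
+        Example (illustrative)::
+
+            with open("speech.wav", "wb") as f:
+                f.write(pcm.to_wav_bytes())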
+ """ + import io + import wave + + # Ensure s16 frames + if self.format != "s16": + arr = self.samples + if isinstance(arr, np.ndarray): + if arr.dtype != np.int16: + # Convert floats to int16 range + if arr.dtype != np.float32: + arr = arr.astype(np.float32) + arr = (np.clip(arr, -1.0, 1.0) * 32767.0).astype(np.int16) + frames = PcmData( + samples=arr, + sample_rate=self.sample_rate, + format="s16", + pts=self.pts, + dts=self.dts, + time_base=self.time_base, + channels=self.channels, + ).to_bytes() + else: + frames = self.to_bytes() + width = 2 + else: + frames = self.to_bytes() + width = 2 + + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(self.channels or 1) + wf.setsampwidth(width) + wf.setframerate(self.sample_rate) + wf.writeframes(frames) + return buf.getvalue() + + @classmethod + def from_response( + cls, + response: Any, + *, + sample_rate: int = 16000, + channels: int = 1, + format: str = "s16", + ) -> Union["PcmData", Iterator["PcmData"], AsyncIterator["PcmData"]]: + """Create PcmData stream(s) from a provider response. + + Supported inputs: + - bytes/bytearray/memoryview -> returns PcmData + - async iterator of bytes or objects with .data -> returns async iterator of PcmData + - iterator of bytes or objects with .data -> returns iterator of PcmData + - already PcmData -> returns PcmData + - single object with .data -> returns PcmData from its data + """ + + # bytes-like returns a single PcmData + if isinstance(response, (bytes, bytearray, memoryview)): + return cls.from_bytes( + bytes(response), + sample_rate=sample_rate, + channels=channels, + format=format, + ) + + # Already a PcmData + if isinstance(response, PcmData): + return response + + # Async iterator + if hasattr(response, "__aiter__"): + + async def _agen(): + width = 2 if format == "s16" else 4 if format == "f32" else 2 + frame_width = width * max(1, channels) + buf = bytearray() + async for item in response: + if isinstance(item, PcmData): + yield item + continue + data = getattr(item, "data", item) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("Async iterator yielded unsupported item type") + buf.extend(bytes(data)) + aligned = (len(buf) // frame_width) * frame_width + if aligned: + chunk = bytes(buf[:aligned]) + del buf[:aligned] + yield cls.from_bytes( + chunk, + sample_rate=sample_rate, + channels=channels, + format=format, + ) + # pad remainder, if any + if buf: + pad_len = (-len(buf)) % frame_width + if pad_len: + buf.extend(b"\x00" * pad_len) + yield cls.from_bytes( + bytes(buf), + sample_rate=sample_rate, + channels=channels, + format=format, + ) + + return _agen() + + # Sync iterator (but skip treating bytes as iterable of ints) + if hasattr(response, "__iter__") and not isinstance( + response, (str, bytes, bytearray, memoryview) + ): + + def _gen(): + width = 2 if format == "s16" else 4 if format == "f32" else 2 + frame_width = width * max(1, channels) + buf = bytearray() + for item in response: + if isinstance(item, PcmData): + yield item + continue + data = getattr(item, "data", item) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("Iterator yielded unsupported item type") + buf.extend(bytes(data)) + aligned = (len(buf) // frame_width) * frame_width + if aligned: + chunk = bytes(buf[:aligned]) + del buf[:aligned] + yield cls.from_bytes( + chunk, + sample_rate=sample_rate, + channels=channels, + format=format, + ) + if buf: + pad_len = (-len(buf)) % frame_width + if pad_len: + buf.extend(b"\x00" * pad_len) + yield 
cls.from_bytes( + bytes(buf), + sample_rate=sample_rate, + channels=channels, + format=format, + ) + + return _gen() + + # Single object with .data + if hasattr(response, "data"): + data = getattr(response, "data") + if isinstance(data, (bytes, bytearray, memoryview)): + return cls.from_bytes( + bytes(data), + sample_rate=sample_rate, + channels=channels, + format=format, + ) + + raise TypeError( + f"Unsupported response type for PcmData.from_response: {type(response)}" + ) diff --git a/agents-core/vision_agents/core/observability/__init__.py b/agents-core/vision_agents/core/observability/__init__.py index cbdfbd52..fe1420c0 100644 --- a/agents-core/vision_agents/core/observability/__init__.py +++ b/agents-core/vision_agents/core/observability/__init__.py @@ -15,6 +15,7 @@ tts_first_byte_ms, tts_bytes_streamed, tts_errors, + tts_events_emitted, inflight_ops, CALL_ATTRS, ) @@ -30,6 +31,7 @@ "tts_first_byte_ms", "tts_bytes_streamed", "tts_errors", + "tts_events_emitted", "inflight_ops", "CALL_ATTRS", ] diff --git a/agents-core/vision_agents/core/observability/metrics.py b/agents-core/vision_agents/core/observability/metrics.py index ac5edb97..066e215a 100644 --- a/agents-core/vision_agents/core/observability/metrics.py +++ b/agents-core/vision_agents/core/observability/metrics.py @@ -58,6 +58,9 @@ "tts.bytes.streamed", unit="By", description="Bytes sent/received for TTS" ) tts_errors = meter.create_counter("tts.errors", description="TTS errors") +tts_events_emitted = meter.create_counter( + "tts.events.emitted", description="Number of TTS events emitted" +) inflight_ops = meter.create_up_down_counter( "voice.ops.inflight", description="Inflight voice ops" diff --git a/agents-core/vision_agents/core/tts/manual_test.py b/agents-core/vision_agents/core/tts/manual_test.py new file mode 100644 index 00000000..4d2473d4 --- /dev/null +++ b/agents-core/vision_agents/core/tts/manual_test.py @@ -0,0 +1,82 @@ +import asyncio +import os +import shutil +import tempfile +import time +from typing import Optional + +from vision_agents.core.tts import TTS +from vision_agents.core.tts.testing import TTSSession +from vision_agents.core.edge.types import PcmData + + +async def manual_tts_to_wav( + tts: TTS, + *, + sample_rate: int = 16000, + channels: int = 1, + text: str = "This is a manual TTS playback test.", + outfile_path: Optional[str] = None, + timeout_s: float = 20.0, + play_env: str = "FFPLAY", +) -> str: + """Generate TTS audio to a WAV file and optionally play with ffplay. + + - Creates the TTS instance via `tts_factory()`. + - Sets desired output format via `set_output_format(sample_rate, channels)`. + - Sends `text` and captures TTSAudioEvent chunks. + - Writes a WAV (s16) file and returns the path. + - If env `play_env` is set to "1" and `ffplay` exists, it plays the file. + + Args: + tts: the TTS instance. + sample_rate: desired sample rate to write. + channels: desired channels to write. + text: text to synthesize. + outfile_path: optional absolute path for the WAV file; if None, temp path. + timeout_s: timeout for first audio to arrive. + play_env: env var name controlling playback (default: FFPLAY). + + Returns: + Path to written WAV file. 
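+
+    Example (illustrative; assumes the Fish plugin and its API key are configured)::
+
+        from vision_agents.plugins import fish
+
+        tts = fish.TTS()
+        path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1)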
+ """ + + tts.set_output_format(sample_rate=sample_rate, channels=channels) + session = TTSSession(tts) + await tts.send(text) + result = await session.wait_for_result(timeout=timeout_s) + if result.errors: + raise RuntimeError(f"TTS errors: {result.errors}") + + # Write WAV file (16kHz mono, s16) + if outfile_path is None: + tmpdir = tempfile.gettempdir() + timestamp = int(time.time()) + outfile_path = os.path.join( + tmpdir, f"tts_manual_test_{tts.__class__.__name__}_{timestamp}.wav" + ) + + pcm_bytes = b"".join(result.speeches) + pcm = PcmData.from_bytes( + pcm_bytes, sample_rate=sample_rate, channels=channels, format="s16" + ) + with open(outfile_path, "wb") as f: + f.write(pcm.to_wav_bytes()) + + # Optional playback + if os.environ.get(play_env) == "1" and shutil.which("ffplay"): + proc = await asyncio.create_subprocess_exec( + "ffplay", + "-autoexit", + "-nodisp", + "-hide_banner", + "-loglevel", + "error", + outfile_path, + ) + try: + await asyncio.wait_for(proc.wait(), timeout=30.0) + except asyncio.TimeoutError: + proc.kill() + + return outfile_path diff --git a/agents-core/vision_agents/core/tts/testing.py b/agents-core/vision_agents/core/tts/testing.py new file mode 100644 index 00000000..1c291d0a --- /dev/null +++ b/agents-core/vision_agents/core/tts/testing.py @@ -0,0 +1,81 @@ +from __future__ import annotations +import asyncio +from dataclasses import dataclass, field +from typing import List + +from . import TTS +from .events import ( + TTSAudioEvent, + TTSErrorEvent, + TTSSynthesisStartEvent, + TTSSynthesisCompleteEvent, +) + + +@dataclass +class TTSResult: + speeches: List[bytes] = field(default_factory=list) + errors: List[Exception] = field(default_factory=list) + started: bool = False + completed: bool = False + + +class TTSSession: + """Test helper to collect TTS events and wait for outcomes. 
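+
+    Handlers for the synthesis start, audio, error, and complete events are
+    subscribed in ``__init__`` and recorded for later inspection.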
+ + Usage: + session = TTSSession(tts) + await tts.send(text) + result = await session.wait_for_result(timeout=10.0) + assert not result.errors + assert result.speeches[0] + """ + + def __init__(self, tts: TTS): + self._tts = tts + self._speeches: List[bytes] = [] + self._errors: List[Exception] = [] + self._started = False + self._completed = False + self._first_event = asyncio.Event() + + @tts.events.subscribe + async def _on_start(ev: TTSSynthesisStartEvent): # type: ignore[name-defined] + self._started = True + + @tts.events.subscribe + async def _on_audio(ev: TTSAudioEvent): # type: ignore[name-defined] + if ev.audio_data: + self._speeches.append(ev.audio_data) + self._first_event.set() + + @tts.events.subscribe + async def _on_error(ev: TTSErrorEvent): # type: ignore[name-defined] + if ev.error: + self._errors.append(ev.error) + self._first_event.set() + + @tts.events.subscribe + async def _on_complete(ev: TTSSynthesisCompleteEvent): # type: ignore[name-defined] + self._completed = True + + @property + def speeches(self) -> List[bytes]: + return self._speeches + + @property + def errors(self) -> List[Exception]: + return self._errors + + async def wait_for_result(self, timeout: float = 10.0) -> TTSResult: + try: + await asyncio.wait_for(self._first_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + # Return whatever we have so far + pass + return TTSResult( + speeches=list(self._speeches), + errors=list(self._errors), + started=self._started, + completed=self._completed, + ) diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index 5653fa8c..dc63dc50 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -1,11 +1,9 @@ import abc import logging -import inspect import time import uuid -from typing import Optional, Dict, Any, Union, Iterator, AsyncIterator +from typing import Optional, Dict, Union, Iterator, AsyncIterator, AsyncGenerator, Any -from getstream.video.rtc.audio_track import AudioStreamTrack from vision_agents.core.events.manager import EventManager from . import events @@ -15,7 +13,18 @@ TTSSynthesisCompleteEvent, TTSErrorEvent, ) -from vision_agents.core.events import PluginInitializedEvent, PluginClosedEvent +from vision_agents.core.events import ( + PluginInitializedEvent, + PluginClosedEvent, + AudioFormat, +) +from ..observability import ( + tts_latency_ms, + tts_bytes_streamed, + tts_errors, + tts_events_emitted, +) +from ..edge.types import PcmData logger = logging.getLogger(__name__) @@ -27,7 +36,7 @@ class TTS(abc.ABC): This abstract class provides the interface for text-to-speech implementations. It handles: - Converting text to speech - - Sending audio data to an output track + - Resampling and rechanneling audio to a desired format - Emitting audio events Events: @@ -47,60 +56,134 @@ def __init__(self, provider_name: Optional[str] = None): provider_name: Name of the TTS provider (e.g., "cartesia", "elevenlabs") """ super().__init__() - self._track: Optional[AudioStreamTrack] = None self.session_id = str(uuid.uuid4()) self.provider_name = provider_name or self.__class__.__name__ self.events = EventManager() self.events.register_events_from_module(events, ignore_not_compatible=True) - self.events.send(PluginInitializedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="TTS", - provider=self.provider_name, - )) - - def set_output_track(self, track: AudioStreamTrack) -> None: - """ - Set the audio track to output speech to. 
- - Args: - track: The audio track object that will receive speech audio - """ - self._track = track + # Desired output audio format (what downstream audio track expects) + # Agent can override via set_output_format + self._desired_sample_rate: int = 16000 + self._desired_channels: int = 1 + self._desired_format: AudioFormat = AudioFormat.PCM_S16 + # Native/provider audio format default (used only if plugin returns raw bytes) + self._native_sample_rate: int = 16000 + self._native_channels: int = 1 + self._native_format: AudioFormat = AudioFormat.PCM_S16 + self.events.send( + PluginInitializedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="TTS", + provider=self.provider_name, + ) + ) - @property - def track(self): - """Get the current output track.""" - return self._track + def set_output_format( + self, + sample_rate: int, + channels: int = 1, + audio_format: AudioFormat = AudioFormat.PCM_S16, + ) -> None: + """Set the desired output audio format for emitted events. - def get_required_framerate(self) -> int: - """ - Get the required framerate for the audio track. - - This method should be overridden by subclasses to return their specific - framerate requirement. Defaults to 16000 Hz. - - Returns: - The required framerate in Hz - """ - return 16000 + The agent should call this with its output track properties so this + TTS instance can resample and rechannel audio appropriately. - def get_required_stereo(self) -> bool: - """ - Get whether the audio track should be stereo or mono. - - This method should be overridden by subclasses to return their specific - stereo requirement. Defaults to False (mono). - - Returns: - True if stereo is required, False for mono + Args: + sample_rate: Desired sample rate in Hz (e.g., 48000) + channels: Desired channel count (1 for mono, 2 for stereo) + audio_format: Desired audio format (defaults to PCM S16) """ - return False + self._desired_sample_rate = int(sample_rate) + self._desired_channels = int(channels) + self._desired_format = audio_format + + # Backwards-compatibility helper if any subclass still calls it + def set_native_format(self, sample_rate: int, channels: int = 1) -> None: + self._native_sample_rate = int(sample_rate) + self._native_channels = int(channels) + + def _normalize_to_pcm(self, item: Union[bytes, bytearray, PcmData, Any]) -> PcmData: + """Normalize a chunk to PcmData using the native provider format.""" + if isinstance(item, PcmData): + return item + data = getattr(item, "data", item) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("Chunk is not bytes or PcmData") + fmt = ( + self._native_format.value + if hasattr(self._native_format, "value") + else "s16" + ) + return PcmData.from_bytes( + bytes(data), + sample_rate=self._native_sample_rate, + channels=self._native_channels, + format=fmt, + ) + + async def _iter_pcm(self, resp: Any) -> AsyncGenerator[PcmData, None]: + """Yield PcmData chunks from a provider response of various shapes.""" + # Single buffer or PcmData + if isinstance(resp, (bytes, bytearray, PcmData)): + yield self._normalize_to_pcm(resp) + return + # Async iterable + if hasattr(resp, "__aiter__"): + async for item in resp: + yield self._normalize_to_pcm(item) + return + # Sync iterable (avoid treating bytes-like as iterable of ints) + if hasattr(resp, "__iter__") and not isinstance(resp, (str, bytes, bytearray)): + for item in resp: + yield self._normalize_to_pcm(item) + return + raise TypeError(f"Unsupported return type from stream_audio: 
{type(resp)}") + + def _emit_chunk( + self, + pcm: PcmData, + idx: int, + is_final: bool, + synthesis_id: str, + text: str, + user: Optional[Dict[str, Any]], + ) -> tuple[int, float]: + """Resample, serialize, emit TTSAudioEvent; return (bytes_len, duration_ms).""" + pcm_out = pcm.resample(self._desired_sample_rate, self._desired_channels) + payload = pcm_out.to_bytes() + # Metrics: counters per chunk + attrs = {"tts_class": self.__class__.__name__} + tts_bytes_streamed.add(len(payload), attributes=attrs) + tts_events_emitted.add(1, attributes=attrs) + self.events.send( + TTSAudioEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + audio_data=payload, + synthesis_id=synthesis_id, + text_source=text, + user_metadata=user, + chunk_index=idx, + is_final_chunk=is_final, + audio_format=self._desired_format, + sample_rate=self._desired_sample_rate, + channels=self._desired_channels, + ) + ) + return len(payload), pcm_out.duration_ms @abc.abstractmethod async def stream_audio( self, text: str, *args, **kwargs - ) -> Union[bytes, Iterator[bytes], AsyncIterator[bytes]]: + ) -> Union[ + bytes, + Iterator[bytes], + AsyncIterator[bytes], + PcmData, + Iterator[PcmData], + AsyncIterator[PcmData], + ]: """ Convert text to speech audio data. @@ -134,126 +217,62 @@ async def send( self, text: str, user: Optional[Dict[str, Any]] = None, *args, **kwargs ): """ - Convert text to speech, send to the output track, and emit an audio event. + Convert text to speech and emit audio events with the desired format. Args: text: The text to convert to speech user: Optional user metadata to include with the audio event *args: Additional arguments **kwargs: Additional keyword arguments - - Raises: - ValueError: If no output track has been set """ - if self._track is None: - raise ValueError("No output track set. 
Call set_output_track() first.") - try: - # Log start of synthesis - start_time = time.time() - synthesis_id = str(uuid.uuid4()) + start_time = time.time() + synthesis_id = str(uuid.uuid4()) - logger.debug( - "Starting text-to-speech synthesis", extra={"text_length": len(text)} - ) + logger.debug( + "Starting text-to-speech synthesis", extra={"text_length": len(text)} + ) - self.events.send(TTSSynthesisStartEvent( + self.events.send( + TTSSynthesisStartEvent( session_id=self.session_id, plugin_name=self.provider_name, text=text, synthesis_id=synthesis_id, user_metadata=user, - )) + ) + ) - # Synthesize audio - audio_data = await self.stream_audio(text, *args, **kwargs) + try: + # Synthesize audio in provider-native format + response = await self.stream_audio(text, *args, **kwargs) - # Calculate synthesis time + # Calculate synthesis setup time synthesis_time = time.time() - start_time - # Track total audio duration and bytes total_audio_bytes = 0 - audio_chunks = 0 - - if isinstance(audio_data, bytes): - total_audio_bytes = len(audio_data) - audio_chunks = 1 - await self._track.write(audio_data) + total_audio_ms = 0.0 + chunk_index = 0 - audio_event = TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=audio_data, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - sample_rate=self._track.framerate if self._track else 16000, + # Fast-path: single buffer -> mark final + if isinstance(response, (bytes, bytearray, PcmData)): + bytes_len, dur_ms = self._emit_chunk( + self._normalize_to_pcm(response), 0, True, synthesis_id, text, user ) - self.events.send(audio_event) # Structured event - elif inspect.isasyncgen(audio_data): - async for chunk in audio_data: - if isinstance(chunk, bytes): - total_audio_bytes += len(chunk) - audio_chunks += 1 - await self._track.write(chunk) - - # Emit structured audio event - self.events.send(TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=chunk, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - chunk_index=audio_chunks - 1, - is_final_chunk=False, # We don't know if it's final yet - sample_rate=self._track.framerate if self._track else 16000, - )) - else: # assume it's a Cartesia TTS chunk object - total_audio_bytes += len(chunk.data) - audio_chunks += 1 - await self._track.write(chunk.data) - - self.events.send(TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=chunk.data, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - chunk_index=audio_chunks - 1, - is_final_chunk=False, # We don't know if it's final yet - sample_rate=self._track.framerate if self._track else 16000, - )) - elif hasattr(audio_data, "__iter__") and not isinstance( - audio_data, (str, bytes, bytearray) - ): - for chunk in audio_data: - total_audio_bytes += len(chunk) - audio_chunks += 1 - await self._track.write(chunk) - - self.events.send(TTSAudioEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - audio_data=chunk, - synthesis_id=synthesis_id, - text_source=text, - user_metadata=user, - chunk_index=audio_chunks - 1, - is_final_chunk=False, # We don't know if it's final yet - sample_rate=self._track.framerate if self._track else 16000, - )) + total_audio_bytes += bytes_len + total_audio_ms += dur_ms + chunk_index = 1 else: - raise TypeError( - f"Unsupported return type from synthesize: {type(audio_data)}" - ) + async for pcm in self._iter_pcm(response): + bytes_len, dur_ms = 
self._emit_chunk( + pcm, chunk_index, False, synthesis_id, text, user + ) + total_audio_bytes += bytes_len + total_audio_ms += dur_ms + chunk_index += 1 - # Estimate audio duration - this is approximate without knowing format details - # Use track framerate if available, otherwise assume 16kHz - sample_rate = self._track.framerate if self._track else 16000 - # For s16 format (16-bit samples), each byte is half a sample - estimated_audio_duration_ms = (total_audio_bytes / 2) / (sample_rate / 1000) + # Use accumulated PcmData duration for total audio duration + estimated_audio_duration_ms = total_audio_ms real_time_factor = ( (synthesis_time * 1000) / estimated_audio_duration_ms @@ -261,38 +280,50 @@ async def send( else None ) - self.events.send(TTSSynthesisCompleteEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - synthesis_id=synthesis_id, - text=text, - user_metadata=user, - total_audio_bytes=total_audio_bytes, - synthesis_time_ms=synthesis_time * 1000, - audio_duration_ms=estimated_audio_duration_ms, - chunk_count=audio_chunks, - real_time_factor=real_time_factor, - )) + self.events.send( + TTSSynthesisCompleteEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + synthesis_id=synthesis_id, + text=text, + user_metadata=user, + total_audio_bytes=total_audio_bytes, + synthesis_time_ms=synthesis_time * 1000, + audio_duration_ms=estimated_audio_duration_ms, + chunk_count=chunk_index, + real_time_factor=real_time_factor, + ) + ) except Exception as e: - self.events.send(TTSErrorEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - error=e, - context="synthesis", - text_source=text, - synthesis_id=synthesis_id, - user_metadata=user, - )) - # ASK: why ? - # Re-raise to allow the caller to handle the error + # Metrics: error counter + tts_errors.add(1, attributes={"tts_class": self.__class__.__name__}) + self.events.send( + TTSErrorEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + error=e, + context="synthesis", + text_source=text, + synthesis_id=synthesis_id or None, + user_metadata=user, + ) + ) raise + finally: + # Metrics: latency histogram for the entire send call + elapsed_ms = (time.time() - start_time) * 1000.0 + tts_latency_ms.record( + elapsed_ms, attributes={"tts_class": self.__class__.__name__} + ) async def close(self): """Close the TTS service and release any resources.""" - self.events.send(PluginClosedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="TTS", - provider=self.provider_name, - cleanup_successful=True, - )) + self.events.send( + PluginClosedEvent( + session_id=self.session_id, + plugin_name=self.provider_name, + plugin_type="TTS", + provider=self.provider_name, + cleanup_successful=True, + ) + ) diff --git a/docs/ai/instructions/ai-tts.md b/docs/ai/instructions/ai-tts.md index 7d694f9f..8cc7552e 100644 --- a/docs/ai/instructions/ai-tts.md +++ b/docs/ai/instructions/ai-tts.md @@ -1,40 +1,89 @@ -## TTS +## TTS Plugin Guide -Here's a minimal example for building a new TTS plugin +Build a TTS plugin that streams audio and emits events. Keep it minimal and follow the project’s layout conventions. -```python +What to create (PEP 420 structure) +- PEP 420: Do NOT add `__init__.py` in plugin folders. 
Use this layout: + - `plugins//pyproject.toml` (depends on `vision-agents`) + - `plugins//vision_agents/plugins//tts.py` + - `plugins//tests/test_tts.py` (pytest tests at plugin root) + - `plugins//example/` (optional, see `plugins/fish/example/fish_tts_example.py`) -class MyTTS(tts.TTS): - def __init__( - self, - voice_id: str = "VR6AewLTigWG4xSOukaG", # Default ElevenLabs voice - model_id: str = "eleven_multilingual_v2", - client: Optional[MyClient] = None, - ): - # it should be possible to pass the client (makes it easier for users to customize things) - # settings that are common to change, like voice id or model id should be configurable as well - super().__init__() - self.voice_id = voice_id - self.client = client if client is not None else MyClient(api_key=api_key) +Implementation essentials - async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: +- Inherit from `vision_agents.core.tts.tts.TTS`. +- Implement `stream_audio(self, text, ...)` and return a single `PcmData`. - audio_stream = self.client.text_to_speech.stream( - text=text, - voice_id=self.voice_id, - output_format=self.output_format, - model_id=self.model_id, - request_options={"chunk_size": 64000}, - ) + ```python + from vision_agents.core.edge.types import PcmData - return audio_stream + async def stream_audio(self, text: str, *_, **__) -> PcmData: + # If your SDK returns raw bytes for the whole utterance + audio_bytes = await my_sdk.tts.bytes(text=..., ...) + return PcmData.from_bytes(audio_bytes, sample_rate=16000, channels=1, format="s16") + ``` -``` +- `stop_audio` can be a no-op (the Agent controls playback): -TODO: the stop part can be generic -TODO: Track handling can be improved + ```python + async def stop_audio(self) -> None: + logger.info("TTS stop requested (no-op)") + ``` -## Testing your TTS +Sample rate is important -TOOD: no good test suite yet +- Pass the provider’s native `sample_rate`, `channels`, and `format` to `PcmData.from_bytes`. The Agent resamples to its output track, but accurate native metadata is required for correct timing and quality. + - If your SDK is streaming, buffer the audio into a single byte string and return one `PcmData`. + +Testing and examples + +- Add pytest tests at `plugins//tests/test_tts.py`. Keep them simple: assert that `stream_audio` yields `PcmData` and that `send()` emits `TTSAudioEvent`. +- Include a minimal example in `plugins//example/` (see `fish_tts_example.py`). + +Manual playback check (reusable) + +- Use the helper `vision_agents.core.tts.manual_test.manual_tts_to_wav` to generate a WAV and optionally play it with `ffplay`. +- Example inside a plugin test: + + ```python + import pytest + from vision_agents.core.tts.manual_test import manual_tts_to_wav + from vision_agents.plugins import fish + + @pytest.mark.integration + async def test_manual_tts(): + # Requires FISH_API_KEY or FISH_AUDIO_API_KEY + tts = fish.TTS() + path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + print("WAV written to:", path) + ``` + +Environment variables + +- Provider API keys (plugin-specific). For Fish: + - `FISH_API_KEY` or `FISH_AUDIO_API_KEY` must be set. +- Optional playback: + - Set `FFPLAY=1` and ensure `ffplay` is in PATH to auto-play the output WAV. + +Test session helper + +- To simplify event handling in tests, use `vision_agents.core.tts.testing.TTSSession`: + + ```python + from vision_agents.core.tts.testing import TTSSession + + tts = MyTTS(...) 
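+  # "MyTTS" is a placeholder for your plugin's TTS subclass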
+ tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) + + await tts.send("Hello") + result = await session.wait_for_result(timeout=10.0) + assert not result.errors + assert result.speeches[0] + ``` + + +References + +- See existing plugins for patterns: `plugins/fish`, `plugins/cartesia`, `plugins/elevenlabs`, `plugins/kokoro`. diff --git a/examples/01_simple_agent_example/simple_agent_example.py b/examples/01_simple_agent_example/simple_agent_example.py index 6e65e382..ed153ad2 100644 --- a/examples/01_simple_agent_example/simple_agent_example.py +++ b/examples/01_simple_agent_example/simple_agent_example.py @@ -33,9 +33,6 @@ async def start_agent() -> None: # Create a call call = agent.edge.client.video.call("default", str(uuid4())) - # Open the demo UI - await agent.edge.open_demo(call) - # Have the agent join the call/room with await agent.join(call): # Example 1: standardized simple response @@ -55,6 +52,9 @@ async def start_agent() -> None: # await agent.say("Hello, how are you?") # await asyncio.sleep(5) + # Open the demo UI + await agent.edge.open_demo(call) + await agent.simple_response("tell me something interesting in a short sentence") await agent.finish() diff --git a/plugins/aws/tests/test_aws.py b/plugins/aws/tests/test_aws.py index d1c39639..504d430e 100644 --- a/plugins/aws/tests/test_aws.py +++ b/plugins/aws/tests/test_aws.py @@ -38,12 +38,7 @@ async def llm(self) -> BedrockLLM: """Test BedrockLLM initialization with a provided client.""" llm = BedrockLLM(model="qwen.qwen3-32b-v1:0", region_name="us-east-1") if not os.environ.get("AWS_BEARER_TOKEN_BEDROCK"): - print(len(os.environ.get("_BEARER_TOKEN_BEDROCK"))) - token = os.environ.get("AWS_BEARER_TOKEN_BEDROCK") - other = os.environ.get("ANTHROPIC_API_KEY") - raise Exception( - f"Please set AWS_BEARER_TOKEN_BEDROCK {len(token)}, {type(token)}. 
{len(other)}, {type(other)}" - ) + pytest.skip("AWS_BEARER_TOKEN_BEDROCK not set – skipping Bedrock tests") llm._conversation = InMemoryConversation("be friendly", []) return llm diff --git a/plugins/cartesia/tests/test_tts.py b/plugins/cartesia/tests/test_tts.py index f3556979..c4141325 100644 --- a/plugins/cartesia/tests/test_tts.py +++ b/plugins/cartesia/tests/test_tts.py @@ -1,176 +1,42 @@ +from dotenv import load_dotenv import os -import asyncio -from unittest.mock import patch, MagicMock import pytest from vision_agents.plugins import cartesia -from vision_agents.core.tts.events import TTSAudioEvent -from getstream.video.rtc.audio_track import AudioStreamTrack - - -############################ -# Test utilities & fixtures -############################ - - -# A simple async iterator yielding a predefined list of byte chunks -class _AsyncBytesIterator: - def __init__(self, chunks): - self._chunks = list(chunks) - - def __aiter__(self): - return self - - async def __anext__(self): - if self._chunks: - return self._chunks.pop(0) - raise StopAsyncIteration - - -# Mock implementation of the Cartesia SDK -class MockAsyncCartesia: - """Light-weight stub mimicking the public surface used by cartesia.TTS.""" - - def __init__(self, api_key=None): - self.api_key = api_key - self.tts = MagicMock() - - # Pre-generate two fake PCM byte chunks (2000 samples each) - mock_audio = [b"\x00\x00" * 1000, b"\x00\x00" * 1000] - - self.tts.bytes = MagicMock( - side_effect=lambda *_, **__: _AsyncBytesIterator(mock_audio.copy()) - ) - - -# Re-usable audio track stub -class MockAudioTrack(AudioStreamTrack): - def __init__(self, framerate: int = 16000): - self.framerate = framerate - self.written_data: list[bytes] = [] - - async def write(self, data: bytes): - self.written_data.append(data) - return True - - -############################ -# Unit tests -############################ - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.cartesia.tts.AsyncCartesia", MockAsyncCartesia) -async def test_cartesia_tts_initialization(): - """cartesia.TTS should instantiate and store the provided api_key.""" - tts = cartesia.TTS(api_key="test-api-key") - assert tts is not None - assert tts.client.api_key == "test-api-key" - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.cartesia.tts.AsyncCartesia", MockAsyncCartesia) -@patch.dict(os.environ, {"CARTESIA_API_KEY": "env-var-api-key"}) -async def test_cartesia_tts_initialization_with_env_var(): - """When no api_key arg is supplied cartesia.TTS should read CARTESIA_API_KEY.""" - tts = cartesia.TTS() # no explicit key - assert tts.client.api_key == "env-var-api-key" - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.cartesia.tts.AsyncCartesia", MockAsyncCartesia) -async def test_cartesia_synthesize_returns_async_iterator(): - """synthesize() should yield an async iterator of PCM byte chunks.""" - tts = cartesia.TTS(api_key="test") - stream = await tts.stream_audio("Hello") - - # Must be async iterable - assert hasattr(stream, "__aiter__") - - collected = [] - async for chunk in stream: - collected.append(chunk) - - assert len(collected) == 2 - assert all(isinstance(c, (bytes, bytearray)) for c in collected) - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.cartesia.tts.AsyncCartesia", MockAsyncCartesia) -async def test_cartesia_send_writes_to_track_and_emits_event(): - tts = cartesia.TTS(api_key="test") - track = MockAudioTrack() - tts.set_output_track(track) - - received = [] - - @tts.events.subscribe - async def _on_audio(event: TTSAudioEvent): - 
received.append(event.audio_data) - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - await tts.send("Hello world") - - # Allow events to be processed - await asyncio.sleep(0.01) - - # Data should be forwarded to track - assert len(track.written_data) == 2 - assert track.written_data == received - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.cartesia.tts.AsyncCartesia", MockAsyncCartesia) -async def test_cartesia_invalid_framerate_raises(): - tts = cartesia.TTS(api_key="test") - bad_track = MockAudioTrack(framerate=44100) - - with pytest.raises(TypeError, match="framerate 44100"): - tts.set_output_track(bad_track) - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.cartesia.tts.AsyncCartesia", MockAsyncCartesia) -async def test_cartesia_send_without_track_raises(): - tts = cartesia.TTS(api_key="test") - - with pytest.raises(ValueError, match="No output track set"): - await tts.send("Hello, world!") - - -############################ -# Optional integration test -############################ - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_cartesia_with_real_api(): - """Integration test against Cartesia cloud – skipped if CARTESIA_API_KEY unset.""" - api_key = os.environ.get("CARTESIA_API_KEY") - if not api_key: - pytest.skip("CARTESIA_API_KEY env var not set – skipping live API test.") - - tts = cartesia.TTS(api_key=api_key) - track = MockAudioTrack() - tts.set_output_track(track) - - # Wait until we either receive audio or hit a timeout - audio_received = asyncio.Event() - - @tts.events.subscribe - async def _on_audio(event: TTSAudioEvent): - audio_received.set() - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - try: - await asyncio.wait_for(tts.send("Hello from Cartesia!"), timeout=30) - except asyncio.TimeoutError: - pytest.fail("Timed out waiting for Cartesia audio response") - - assert len(track.written_data) > 0, "No audio data received from Cartesia" +from vision_agents.core.tts.manual_test import manual_tts_to_wav +from vision_agents.core.tts.testing import TTSSession + +# Load environment variables +load_dotenv() + + +class TestCartesiaIntegration: + def tts(self) -> cartesia.TTS: # type: ignore[name-defined] + api_key = os.environ.get("CARTESIA_API_KEY") + if not api_key: + pytest.skip("CARTESIA_API_KEY env var not set – skipping live API test.") + return cartesia.TTS(api_key=api_key) + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_cartesia_with_real_api(self): + tts = self.tts() + tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) + await tts.send("Hello from Cartesia!") + result = await session.wait_for_result(timeout=30) + assert not result.errors + assert len(result.speeches) > 0 + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_cartesia_tts_convert_text_to_audio_manual_test(self): + api_key = os.environ.get("CARTESIA_API_KEY") + if not api_key: + pytest.skip( + "CARTESIA_API_KEY env var not set – skipping manual playback test." 
+ ) + tts = self.tts() + path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + print("Cartesia TTS audio written to:", path) diff --git a/plugins/cartesia/vision_agents/plugins/cartesia/tts.py b/plugins/cartesia/vision_agents/plugins/cartesia/tts.py index d37acae1..e20f28af 100644 --- a/plugins/cartesia/vision_agents/plugins/cartesia/tts.py +++ b/plugins/cartesia/vision_agents/plugins/cartesia/tts.py @@ -2,7 +2,7 @@ import logging import os -from typing import Optional, cast +from typing import Optional, cast, AsyncIterator, Iterator from cartesia import AsyncCartesia from cartesia.tts import ( @@ -12,7 +12,7 @@ ) from vision_agents.core import tts -from getstream.video.rtc.audio_track import AudioStreamTrack +from vision_agents.core.edge.types import PcmData class TTS(tts.TTS): @@ -51,23 +51,10 @@ def __init__( ) self.sample_rate = sample_rate - def get_required_framerate(self) -> int: - """Get the required framerate for Cartesia TTS.""" - return self.sample_rate - - def get_required_stereo(self) -> bool: - """Get whether Cartesia TTS requires stereo audio.""" - return False # Cartesia returns mono audio - - def set_output_track(self, track: AudioStreamTrack) -> None: # noqa: D401 - if track.framerate != self.sample_rate: - raise TypeError( - f"Track framerate {track.framerate} ≠ expected {self.sample_rate}" - ) - super().set_output_track(track) - - async def stream_audio(self, text: str, *_, **__) -> bytes: # noqa: D401 - """Generate speech and yield raw PCM chunks.""" + async def stream_audio( + self, text: str, *_, **__ + ) -> PcmData | Iterator[PcmData] | AsyncIterator[PcmData]: # noqa: D401 + """Generate speech and return a stream of PcmData.""" output_format: OutputFormat_RawParams = { "container": "raw", @@ -90,11 +77,9 @@ async def stream_audio(self, text: str, *_, **__) -> bytes: # noqa: D401 voice=voice_param, ) - async def _audio_chunk_stream(): # noqa: D401 - async for chunk in response: - yield bytes(chunk) - - return _audio_chunk_stream() + return PcmData.from_response( + response, sample_rate=self.sample_rate, channels=1, format="s16" + ) async def stop_audio(self) -> None: """ @@ -104,9 +89,4 @@ async def stop_audio(self) -> None: Returns: None """ - try: - (await self.track.flush(),) - logging.info("🎤 Stopping audio track for TTS") - return - except Exception as e: - logging.error(f"Error flushing audio track: {e}") + logging.info("🎤 Cartesia TTS stop requested (no-op)") diff --git a/plugins/elevenlabs/tests/test_tts.py b/plugins/elevenlabs/tests/test_tts.py index cb7ca86b..c1b349dc 100644 --- a/plugins/elevenlabs/tests/test_tts.py +++ b/plugins/elevenlabs/tests/test_tts.py @@ -1,287 +1,35 @@ import os import pytest -import asyncio -from unittest.mock import patch, MagicMock +from vision_agents.core.tts.testing import TTSSession from vision_agents.plugins import elevenlabs -from vision_agents.core.tts.events import TTSAudioEvent, TTSErrorEvent -from getstream.video.rtc.audio_track import AudioStreamTrack - - -# Mock audio track for testing -class MockAudioTrack(AudioStreamTrack): - def __init__(self): - self.framerate = 16000 - self.written_data = [] - - async def write(self, data): - self.written_data.append(data) - return True - - -# Mock AsyncElevenLabs client for testing -class MockAsyncElevenLabsClient: - def __init__(self, api_key=None): - self.api_key = api_key - self.text_to_speech = MagicMock() - - # Create a mock audio stream that returns a few chunks of audio - mock_audio = [b"\x00\x00" * 1000, b"\x00\x00" * 1000] - - # Mock the async stream 
method to return an async generator - async def mock_stream(*args, **kwargs): - for chunk in mock_audio: - yield chunk - - self.text_to_speech.stream = mock_stream - - -@pytest.mark.asyncio -@patch( - "vision_agents.plugins.elevenlabs.tts.AsyncElevenLabs", MockAsyncElevenLabsClient -) -async def test_elevenlabs_tts_initialization(): - """Test that the ElevenLabs TTS initializes correctly with explicit API key.""" - tts = elevenlabs.TTS(api_key="test-api-key") - assert tts is not None - # The mock client should have the api_key attribute - assert hasattr(tts.client, "api_key") - assert tts.client.api_key == "test-api-key" - - -@pytest.mark.asyncio -@patch( - "vision_agents.plugins.elevenlabs.tts.AsyncElevenLabs", MockAsyncElevenLabsClient -) -@patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-var-api-key"}) -async def test_elevenlabs_tts_initialization_with_env_var(): - """ElevenLabsTTS should use ELEVENLABS_API_KEY when no key argument is given.""" - - tts = elevenlabs.TTS() # no explicit key provided - assert tts is not None - assert tts.client.api_key == "env-var-api-key" - - -@pytest.mark.asyncio -@patch( - "vision_agents.plugins.elevenlabs.tts.AsyncElevenLabs", MockAsyncElevenLabsClient -) -async def test_elevenlabs_tts_synthesize(): - """Test that synthesize returns an audio stream.""" - tts = elevenlabs.TTS(api_key="test-api-key") - - # Test that synthesize returns an iterator - text = "Hello, world!" - audio_stream = await tts.stream_audio(text) - - # Check that it's an async iterator - assert hasattr(audio_stream, "__aiter__") - - # Check that we can get chunks from it - chunks = [] - async for chunk in audio_stream: - chunks.append(chunk) - assert len(chunks) > 0 - assert all(isinstance(chunk, bytes) for chunk in chunks) - - -@pytest.mark.asyncio -@patch( - "vision_agents.plugins.elevenlabs.tts.AsyncElevenLabs", MockAsyncElevenLabsClient -) -async def test_elevenlabs_tts_send(): - """Test that send writes audio to the track and emits events.""" - tts = elevenlabs.TTS(api_key="test-api-key") - - # Create a mock audio track - track = MockAudioTrack() - tts.set_output_track(track) - - # Track emitted audio events - emitted_audio = [] - - @tts.events.subscribe - async def on_audio(event: TTSAudioEvent): - emitted_audio.append(event.audio_data) - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - # Send text to the TTS - text = "Hello, world!" 
- await tts.send(text) - - # Allow events to be processed - await asyncio.sleep(0.01) - - # Check that audio was written to the track - assert len(track.written_data) > 0 - - # Check that audio events were emitted - assert len(emitted_audio) > 0 - assert emitted_audio == track.written_data - - -@pytest.mark.asyncio -@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient) -async def test_elevenlabs_tts_send_without_track(): - """Test that sending without setting a track raises an error.""" - tts = elevenlabs.TTS(api_key="test-api-key") - - # Sending without setting a track should raise ValueError - with pytest.raises(ValueError, match="No output track set"): - await tts.send("Hello, world!") - - -@pytest.mark.asyncio -@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient) -async def test_elevenlabs_tts_invalid_framerate(): - """Test that setting a track with invalid framerate raises an error.""" - tts = elevenlabs.TTS(api_key="test-api-key") - - # Create a mock audio track with invalid framerate - invalid_track = MagicMock(spec=AudioStreamTrack) - invalid_track.framerate = 44100 - - # Setting the invalid track should raise TypeError - with pytest.raises(TypeError, match="Invalid framerate"): - tts.set_output_track(invalid_track) - - -@pytest.mark.asyncio -@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient) -async def test_elevenlabs_tts_with_custom_client(): - """Test that ElevenLabs TTS can be initialized with a custom client.""" - # Create a custom mock client - custom_client = MockAsyncElevenLabsClient(api_key="custom-api-key") - - # Initialize TTS with the custom client - tts = elevenlabs.TTS(client=custom_client) - - # Verify that the custom client is used - assert tts.client is custom_client - assert tts.client.api_key == "custom-api-key" - - -@pytest.mark.asyncio -@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient) -async def test_elevenlabs_tts_stop_method(): - """Test that the stop method properly flushes the audio track.""" - tts = elevenlabs.TTS(api_key="test-api-key") - - # Create a mock audio track with flush method - track = MockAudioTrack() - track.flush = MagicMock(return_value=asyncio.Future()) - track.flush.return_value.set_result(None) - - tts.set_output_track(track) - - # Call stop method - await tts.stop_audio() - - # Verify that flush was called on the track - track.flush.assert_called_once() - - -@pytest.mark.asyncio -@patch("elevenlabs.client.AsyncElevenLabs", MockAsyncElevenLabsClient) -async def test_elevenlabs_tts_stop_method_handles_exceptions(): - """Test that the stop method handles flush exceptions gracefully.""" - tts = elevenlabs.TTS(api_key="test-api-key") - - # Create a mock audio track with flush method that raises an exception - track = MockAudioTrack() - track.flush = MagicMock(side_effect=Exception("Flush error")) - - tts.set_output_track(track) - - # Call stop method - should not raise an exception - await tts.stop_audio() - - # Verify that flush was called on the track - track.flush.assert_called_once() - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_elevenlabs_with_real_api(): - """ - Integration test with the real ElevenLabs API. - - This test uses the actual ElevenLabs API with the - ELEVENLABS_API_KEY environment variable. - It will be skipped if the environment variable is not set. - - To set up the ELEVENLABS_API_KEY: - 1. Sign up for an ElevenLabs account at https://elevenlabs.io - 2. Create an API key in your ElevenLabs dashboard - 3. 
Add to your .env file: ELEVENLABS_API_KEY=your_api_key_here - """ - # Check if the required API key is available - api_key = os.environ.get("ELEVENLABS_API_KEY") - - # Skip the test if the ELEVENLABS_API_KEY environment variable is not set - if not api_key: - pytest.skip( - "ELEVENLABS_API_KEY environment variable not set. Add it to your .env file." - ) - - # Create a real ElevenLabs TTS instance with the API key explicitly set - tts = elevenlabs.TTS(api_key=api_key) - - # Create a mock audio track to capture the output - track = MockAudioTrack() - tts.set_output_track(track) - - # Track audio events - audio_received = asyncio.Event() - received_chunks = [] - - @tts.events.subscribe - async def on_audio(event: TTSAudioEvent): - received_chunks.append(event.audio_data) - audio_received.set() - - # Track API errors - api_errors = [] - - @tts.events.subscribe - async def on_error(event: TTSErrorEvent): - api_errors.append(event.error) - audio_received.set() # Unblock the waiting - - # Allow event subscriptions to be processed - await asyncio.sleep(0.01) - - try: - # Use a short text to minimize API usage - text = "This is a test of the ElevenLabs text-to-speech API." - - # Send the text to generate speech - send_task = asyncio.create_task(tts.send(text)) - - # Wait for either audio or an error +from vision_agents.core.tts.manual_test import manual_tts_to_wav + + +class TestElevenLabsIntegration: + @pytest.fixture + def tts(self) -> elevenlabs.TTS: + api_key = os.environ.get("ELEVENLABS_API_KEY") + if not api_key: + pytest.skip( + "ELEVENLABS_API_KEY environment variable not set. Add it to your .env file." + ) + return elevenlabs.TTS(api_key=api_key) + + @pytest.mark.integration + async def test_elevenlabs_with_real_api(self, tts): + tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) try: - await asyncio.wait_for(audio_received.wait(), timeout=15.0) - except asyncio.TimeoutError: - # Cancel the task if it's taking too long - send_task.cancel() - pytest.fail("No audio or error received within timeout") - - # Check if we received any API errors - if api_errors: - pytest.skip(f"API error received: {api_errors[0]}") - - # Try to ensure the send task completes - try: - await send_task + await tts.send("This is a test of the ElevenLabs text-to-speech API.") + result = await session.wait_for_result(timeout=15.0) except Exception as e: - pytest.skip(f"Exception during TTS generation: {e}") + pytest.skip(f"Unexpected error in ElevenLabs test: {e}") + + assert not result.errors + assert len(result.speeches) > 0 - # Verify that we received audio data - assert len(received_chunks) > 0, "No audio chunks were received" - except Exception as e: - pytest.skip(f"Unexpected error in ElevenLabs test: {e}") - finally: - # Event handlers are automatically cleaned up when the TTS instance is destroyed - pass + @pytest.mark.integration + async def test_elevenlabs_tts_convert_text_to_audio_manual_test(self, tts): + path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + print("ElevenLabs TTS audio written to:", path) diff --git a/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py b/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py index a84f2869..15a862e7 100644 --- a/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py +++ b/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py @@ -1,11 +1,10 @@ import logging +import os +from typing import AsyncIterator, Iterator, Optional -from vision_agents.core import tts from elevenlabs.client import 
AsyncElevenLabs -from getstream.video.rtc.audio_track import AudioStreamTrack -from typing import AsyncIterator, Optional - -import os +from vision_agents.core import tts +from vision_agents.core.edge.types import PcmData class TTS(tts.TTS): @@ -37,20 +36,9 @@ def __init__( self.model_id = model_id self.output_format = "pcm_16000" - def get_required_framerate(self) -> int: - """Get the required framerate for ElevenLabs TTS.""" - return 16000 - - def get_required_stereo(self) -> bool: - """Get whether ElevenLabs TTS requires stereo audio.""" - return False # ElevenLabs returns mono audio - - def set_output_track(self, track: AudioStreamTrack) -> None: - if track.framerate != 16000: - raise TypeError("Invalid framerate, audio track only supports 16000") - super().set_output_track(track) - - async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: + async def stream_audio( + self, text: str, *_, **__ + ) -> PcmData | Iterator[PcmData] | AsyncIterator[PcmData]: """ Convert text to speech using ElevenLabs API. @@ -69,7 +57,9 @@ async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: request_options={"chunk_size": 64000}, ) - return audio_stream + return PcmData.from_response( + audio_stream, sample_rate=16000, channels=1, format="s16" + ) async def stop_audio(self) -> None: """ @@ -79,11 +69,4 @@ async def stop_audio(self) -> None: Returns: None """ - if self.track is not None: - try: - await self.track.flush() - logging.info("🎤 Stopping audio track for TTS") - except Exception as e: - logging.error(f"Error flushing audio track: {e}") - else: - logging.warning("No audio track to stop") + logging.info("🎤 ElevenLabs TTS stop requested (no-op)") diff --git a/plugins/fish/tests/test_tts.py b/plugins/fish/tests/test_tts.py index cbdede10..31e07b43 100644 --- a/plugins/fish/tests/test_tts.py +++ b/plugins/fish/tests/test_tts.py @@ -1,97 +1,37 @@ -import asyncio +import os import pytest from dotenv import load_dotenv from vision_agents.plugins import fish -from vision_agents.core.tts.events import TTSAudioEvent, TTSErrorEvent -from getstream.video.rtc.audio_track import AudioStreamTrack +from vision_agents.core.tts.manual_test import manual_tts_to_wav +from vision_agents.core.tts.testing import TTSSession # Load environment variables load_dotenv() -# Audio track for capturing test output -class MockAudioTrack(AudioStreamTrack): - def __init__(self, framerate: int = 16000): - self.framerate = framerate - self.written_data = [] - async def write(self, data: bytes): - self.written_data.append(data) - return True +class TestFishTTS: + @pytest.fixture + def tts(self) -> fish.TTS: + return fish.TTS() + @pytest.mark.integration + async def test_fish_tts_convert_text_to_audio_manual_test(self, tts: fish.TTS): + if not (os.environ.get("FISH_API_KEY") or os.environ.get("FISH_AUDIO_API_KEY")): + pytest.skip( + "FISH_API_KEY/FISH_AUDIO_API_KEY not set; skipping manual playback test." + ) + await manual_tts_to_wav(tts, sample_rate=16000, channels=1) -@pytest.mark.integration -async def test_fish_tts_convert_text_to_audio(): - """ - Integration test with the real Fish Audio API. - - This test uses the actual Fish Audio API with the - FISH_AUDIO_API_KEY environment variable. - It will be skipped if the environment variable is not set. - - To set up the FISH_AUDIO_API_KEY: - 1. Sign up for a Fish Audio account at https://fish.audio - 2. Create an API key in your Fish Audio dashboard - 3. 
Add to your .env file: FISH_AUDIO_API_KEY=your_api_key_here - """ - - - # Create a real Fish Audio TTS instance - tts = fish.TTS() - - # Create an audio track to capture the output - track = MockAudioTrack() - tts.set_output_track(track) - - # Track audio events - audio_received = asyncio.Event() - received_chunks = [] - - @tts.events.subscribe - async def on_audio(event: TTSAudioEvent): - received_chunks.append(event.audio_data) - audio_received.set() - - # Track API errors - api_errors = [] - - @tts.events.subscribe - async def on_error(event: TTSErrorEvent): - api_errors.append(event.error) - audio_received.set() # Unblock the waiting - - # Allow event subscriptions to be processed - await asyncio.sleep(0.01) - - try: - # Use a short text to minimize API usage + @pytest.mark.integration + async def test_fish_tts_convert_text_to_audio(self, tts: fish.TTS): + tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) text = "Hello from Fish Audio." - - # Send the text to generate speech - send_task = asyncio.create_task(tts.send(text)) - - # Wait for either audio or an error - try: - await asyncio.wait_for(audio_received.wait(), timeout=15.0) - except asyncio.TimeoutError: - # Cancel the task if it's taking too long - send_task.cancel() - pytest.fail("No audio or error received within timeout") - - # Check if we received any API errors - if api_errors: - pytest.skip(f"API error received: {api_errors[0]}") - - # Try to ensure the send task completes - try: - await send_task - except Exception as e: - pytest.skip(f"Exception during TTS generation: {e}") - - # Verify that we received audio data - assert len(received_chunks) > 0, "No audio chunks were received" - assert len(track.written_data) > 0, "No audio data was written to track" - except Exception as e: - pytest.skip(f"Unexpected error in Fish Audio test: {e}") + await tts.send(text) + await session.wait_for_result(timeout=15.0) + + assert not session.errors + assert len(session.speeches) > 0 diff --git a/plugins/fish/vision_agents/plugins/fish/tts.py b/plugins/fish/vision_agents/plugins/fish/tts.py index b34b7398..78bd76f3 100644 --- a/plugins/fish/vision_agents/plugins/fish/tts.py +++ b/plugins/fish/vision_agents/plugins/fish/tts.py @@ -1,10 +1,10 @@ import logging import os -from typing import AsyncIterator, Optional +from typing import AsyncIterator, Iterator, Optional from fish_audio_sdk import Session, TTSRequest -from getstream.video.rtc.audio_track import AudioStreamTrack from vision_agents.core import tts +from vision_agents.core.edge.types import PcmData logger = logging.getLogger(__name__) @@ -12,7 +12,7 @@ class TTS(tts.TTS): """ Fish Audio Text-to-Speech implementation. - + Fish Audio provides high-quality, multilingual text-to-speech synthesis with support for voice cloning via reference audio. 
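A short usage sketch of the voice-cloning path this docstring describes; the `reference_id` value is a placeholder for a model ID from your Fish Audio dashboard, not a real one:

```python
import asyncio

from vision_agents.plugins import fish


async def main() -> None:
    # reference_id selects a voice model cloned from reference audio.
    # Requires FISH_API_KEY or FISH_AUDIO_API_KEY in the environment.
    tts = fish.TTS(reference_id="your-fish-model-id")  # placeholder ID
    tts.set_output_format(sample_rate=16000, channels=1)
    await tts.send("Hello in the cloned voice.")


asyncio.run(main())
```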
@@ -22,7 +22,7 @@ class TTS(tts.TTS): def __init__( self, api_key: Optional[str] = None, - reference_id: Optional[str] = None, + reference_id: Optional[str] = "03397b4c4be74759b72533b663fbd001", base_url: Optional[str] = None, client: Optional[Session] = None, ): @@ -39,7 +39,10 @@ def __init__( super().__init__(provider_name="fish") if not api_key: - api_key = os.environ.get("FISH_API_KEY") + # Support both env names for compatibility + api_key = os.environ.get("FISH_API_KEY") or os.environ.get( + "FISH_AUDIO_API_KEY" + ) if client is not None: self.client = client @@ -50,35 +53,9 @@ def __init__( self.reference_id = reference_id - # Fish Audio typically outputs at 44100 Hz, but we'll use 16000 for compatibility - # Note: You may need to adjust this based on Fish Audio's actual output - self.output_framerate = 16000 - - def get_required_framerate(self) -> int: - """Get the required framerate for Fish Audio TTS.""" - return self.output_framerate - - def get_required_stereo(self) -> bool: - """Get whether Fish Audio TTS requires stereo audio.""" - return False # Fish Audio typically returns mono audio - - def set_output_track(self, track: AudioStreamTrack) -> None: - """ - Set the output audio track. - - Args: - track: The audio track to output to. - - Raises: - TypeError: If the track framerate doesn't match requirements. - """ - if track.framerate != self.output_framerate: - raise TypeError( - f"Invalid framerate, audio track only supports {self.output_framerate}" - ) - super().set_output_track(track) - - async def stream_audio(self, text: str, *_, **kwargs) -> AsyncIterator[bytes]: + async def stream_audio( + self, text: str, *_, **kwargs + ) -> PcmData | Iterator[PcmData] | AsyncIterator[PcmData]: """ Convert text to speech using Fish Audio API. @@ -91,36 +68,35 @@ async def stream_audio(self, text: str, *_, **kwargs) -> AsyncIterator[bytes]: """ # Build the TTS request tts_request_kwargs = {"text": text} - + # Add reference_id if configured if self.reference_id: tts_request_kwargs["reference_id"] = self.reference_id - + # Allow overriding via kwargs (e.g., for dynamic reference audio) tts_request_kwargs.update(kwargs) - - tts_request = TTSRequest(format="pcm", sample_rate=16000, normalize=True,reference_id="03397b4c4be74759b72533b663fbd001", **tts_request_kwargs) - # Stream audio from Fish Audio - audio_stream = self.client.tts.awaitable(tts_request) + tts_request = TTSRequest( + format="pcm", + sample_rate=16000, + normalize=True, + **tts_request_kwargs, + ) - return audio_stream + # Stream audio from Fish Audio; let PcmData normalize response types + stream = self.client.tts.awaitable(tts_request) + return PcmData.from_response( + stream, sample_rate=16000, channels=1, format="s16" + ) async def stop_audio(self) -> None: """ Clears the queue and stops playing audio. - + This method can be used manually or under the hood in response to turn events. 
Returns: None """ - if self.track is not None: - try: - await self.track.flush() - logger.info("🎤 Stopping audio track for Fish Audio TTS") - except Exception as e: - logger.error(f"Error flushing audio track: {e}") - else: - logger.warning("No audio track to stop") - + # No internal output track to flush; agent manages playback + logger.info("🎤 Fish TTS stop requested (no-op)") diff --git a/plugins/kokoro/tests/test_tts.py b/plugins/kokoro/tests/test_tts.py index bb3c5c26..c19d6b03 100644 --- a/plugins/kokoro/tests/test_tts.py +++ b/plugins/kokoro/tests/test_tts.py @@ -1,164 +1,18 @@ -from unittest.mock import patch, MagicMock -import asyncio - -import numpy as np import pytest +from vision_agents.core.tts.manual_test import manual_tts_to_wav -from vision_agents.plugins import kokoro -from vision_agents.core.tts.events import TTSAudioEvent -from getstream.video.rtc.audio_track import AudioStreamTrack - - -############################ -# Test utilities & fixtures -############################ - - -class MockAudioTrack(AudioStreamTrack): - def __init__(self, framerate: int = 24_000): - self.framerate = framerate - self.written_data: list[bytes] = [] - - async def write(self, data: bytes): - self.written_data.append(data) - return True - - -class _MockKPipeline: # noqa: D401 - """Very small stub that mimics ``kokoro.KPipeline`` callable behaviour.""" - - def __init__(self, *_, **__): - pass - - def __call__(self, text, *, voice, speed, split_pattern): # noqa: D401 - # Produce two mini 20 ms chunks of silence at 24 kHz - blank = np.zeros(480, dtype=np.float32) # 480 samples @ 24 kHz = 20 ms - for _ in range(2): - yield text, voice, blank - - -############################ -# Unit-tests -############################ - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_tts_initialization(): - tts = kokoro.TTS() - assert tts is not None - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_synthesize_returns_iterator(): - tts = kokoro.TTS() - stream = await tts.stream_audio("Hello") - - # Should be an async iterator (list of bytes) - chunks = [] - async for chunk in stream: - chunks.append(chunk) - - assert len(chunks) == 2 - assert all(isinstance(c, (bytes, bytearray)) for c in chunks) - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_send_writes_and_emits(): - tts = kokoro.TTS() - track = MockAudioTrack() - tts.set_output_track(track) - - received = [] - - @tts.events.subscribe - async def _on_audio(event: TTSAudioEvent): - # Extract the audio data from the event - if hasattr(event, "audio_data") and event.audio_data is not None: - received.append(event.audio_data) - else: - received.append(b"") - - # Allow event subscription to be processed - await asyncio.sleep(0.01) - - await tts.send("Hello world") - - # Allow events to be processed - await asyncio.sleep(0.01) - - assert len(track.written_data) == 2 - assert track.written_data == received - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_invalid_framerate(): - tts = kokoro.TTS() - bad_track = MockAudioTrack(framerate=16_000) - - with pytest.raises(TypeError): - tts.set_output_track(bad_track) - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_send_without_track(): - tts = kokoro.TTS() - with 
pytest.raises(ValueError): - await tts.send("Hi") - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_tts_with_custom_client(): - """Test that Kokoro TTS can be initialized with a custom client.""" - # Create a custom mock client - custom_client = _MockKPipeline() - - # Initialize TTS with the custom client - tts = kokoro.TTS(client=custom_client) - - # Verify that the custom client is used - assert tts.client is custom_client - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_tts_stop_method(): - """Test that the stop method properly flushes the audio track.""" - tts = kokoro.TTS() - - # Create a mock audio track with flush method - track = MockAudioTrack() - track.flush = MagicMock(return_value=asyncio.Future()) - track.flush.return_value.set_result(None) - - tts.set_output_track(track) - - # Call stop method - await tts.stop_audio() - - # Verify that flush was called on the track - track.flush.assert_called_once() - - -@pytest.mark.asyncio -@patch("vision_agents.plugins.kokoro.tts.KPipeline", _MockKPipeline) -async def test_kokoro_tts_stop_method_handles_exceptions(): - """Test that the stop method handles flush exceptions gracefully.""" - tts = kokoro.TTS() - - # Create a mock audio track with flush method that raises an exception - track = MockAudioTrack() - track.flush = MagicMock(side_effect=Exception("Flush error")) - tts.set_output_track(track) +class TestKokoroIntegration: + @pytest.fixture + def tts(self): # returns kokoro TTS if available + try: + import kokoro # noqa: F401 + except Exception: + pytest.skip("kokoro package not installed; skipping manual playback test.") + from vision_agents.plugins import kokoro as kokoro_plugin - # Call stop method - should not raise an exception - await tts.stop_audio() + return kokoro_plugin.TTS() - # Verify that flush was called on the track - track.flush.assert_called_once() + @pytest.mark.integration + async def test_kokoro_tts_convert_text_to_audio_manual_test(self, tts): + await manual_tts_to_wav(tts, sample_rate=24000, channels=1) diff --git a/plugins/kokoro/vision_agents/plugins/kokoro/tts.py b/plugins/kokoro/vision_agents/plugins/kokoro/tts.py index 82483f8c..0bbbdea3 100644 --- a/plugins/kokoro/vision_agents/plugins/kokoro/tts.py +++ b/plugins/kokoro/vision_agents/plugins/kokoro/tts.py @@ -2,12 +2,12 @@ import asyncio import logging +from typing import AsyncIterator, Iterator, List, Optional import numpy as np -from typing import AsyncIterator, List, Optional from vision_agents.core import tts -from getstream.video.rtc.audio_track import AudioStreamTrack +from vision_agents.core.edge.types import PcmData try: from kokoro import KPipeline # type: ignore @@ -44,22 +44,9 @@ def __init__( self.sample_rate = sample_rate self.client = client if client is not None else self._pipeline - def get_required_framerate(self) -> int: - """Get the required framerate for Kokoro TTS.""" - return self.sample_rate - - def get_required_stereo(self) -> bool: - """Get whether Kokoro TTS requires stereo audio.""" - return False # Kokoro returns mono audio - - def set_output_track(self, track: AudioStreamTrack) -> None: # noqa: D401 - if track.framerate != self.sample_rate: - raise TypeError( - f"Invalid framerate {track.framerate}, Kokoro requires {self.sample_rate} Hz" - ) - super().set_output_track(track) - - async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: # noqa: D401 + async def stream_audio( 
+ self, text: str, *_, **__ + ) -> PcmData | Iterator[PcmData] | AsyncIterator[PcmData]: # noqa: D401 loop = asyncio.get_event_loop() chunks: List[bytes] = await loop.run_in_executor( None, lambda: list(self._generate_chunks(text)) @@ -67,7 +54,9 @@ async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: # no async def _aiter(): for chunk in chunks: - yield chunk + yield PcmData.from_bytes( + chunk, sample_rate=self.sample_rate, channels=1, format="s16" + ) return _aiter() @@ -76,11 +65,7 @@ async def stop_audio(self) -> None: Clears the queue and stops playing audio. """ - try: - await self.track.flush() - return - except Exception as e: - logging.error(f"Error flushing audio track: {e}") + logging.info("🎤 Kokoro TTS stop requested (no-op)") def _generate_chunks(self, text: str): for _gs, _ps, audio in self._pipeline( diff --git a/tests/test_tts_base.py b/tests/test_tts_base.py new file mode 100644 index 00000000..6cce0f87 --- /dev/null +++ b/tests/test_tts_base.py @@ -0,0 +1,215 @@ +import asyncio +from typing import AsyncIterator, Iterator, List + +import pytest + +from vision_agents.core.tts.tts import TTS as BaseTTS +from vision_agents.core.tts.events import ( + TTSAudioEvent, + TTSErrorEvent, + TTSSynthesisStartEvent, + TTSSynthesisCompleteEvent, +) +from vision_agents.core.edge.types import PcmData + + +class DummyTTSBytesSingle(BaseTTS): + async def stream_audio(self, text: str, *_, **__) -> bytes: + # 16-bit PCM mono (s16), 100 samples -> 200 bytes + self._native_sample_rate = 16000 + self._native_channels = 1 + return b"\x00\x00" * 100 + + async def stop_audio(self) -> None: # pragma: no cover - noop + return None + + +class DummyTTSBytesAsync(BaseTTS): + async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: + self._native_sample_rate = 16000 + self._native_channels = 1 + + async def _agen(): + # Unaligned chunk sizes to test aggregator + yield b"\x00\x00" * 33 + b"\x00" # odd size + yield b"\x00\x00" * 10 + + return _agen() + + async def stop_audio(self) -> None: # pragma: no cover - noop + return None + + +class DummyTTSIterSync(BaseTTS): + async def stream_audio(self, text: str, *_, **__) -> Iterator[bytes]: + self._native_sample_rate = 16000 + self._native_channels = 1 + return iter([b"\x00\x00" * 50, b"\x00\x00" * 25]) + + async def stop_audio(self) -> None: # pragma: no cover - noop + return None + + +class DummyTTSPcmStereoToMono(BaseTTS): + async def stream_audio(self, text: str, *_, **__) -> PcmData: + # 2 channels interleaved: 100 frames (per channel) -> 200 samples -> 400 bytes + frames = b"\x01\x00\x01\x00" * 100 # L(1), R(1) + pcm = PcmData.from_bytes(frames, sample_rate=16000, channels=2, format="s16") + return pcm + + async def stop_audio(self) -> None: # pragma: no cover - noop + return None + + +class DummyTTSPcmResample(BaseTTS): + async def stream_audio(self, text: str, *_, **__) -> PcmData: + # 16k mono, 200 samples (duration = 200/16000 s) + data = b"\x00\x00" * 200 + pcm = PcmData.from_bytes(data, sample_rate=16000, channels=1, format="s16") + return pcm + + async def stop_audio(self) -> None: # pragma: no cover - noop + return None + + +class DummyTTSError(BaseTTS): + async def stream_audio(self, text: str, *_, **__): + raise RuntimeError("boom") + + async def stop_audio(self) -> None: # pragma: no cover - noop + return None + + +@pytest.mark.asyncio +async def test_tts_bytes_single_emits_events_and_bytes(): + tts = DummyTTSBytesSingle() + tts.set_output_format(sample_rate=16000, channels=1) + + events: List[type] = 
[] + audio_chunks: List[bytes] = [] + + @tts.events.subscribe + async def _on_start(ev: TTSSynthesisStartEvent): + events.append(TTSSynthesisStartEvent) + + @tts.events.subscribe + async def _on_audio(ev: TTSAudioEvent): + events.append(TTSAudioEvent) + if ev.audio_data: + audio_chunks.append(ev.audio_data) + + @tts.events.subscribe + async def _on_complete(ev: TTSSynthesisCompleteEvent): + events.append(TTSSynthesisCompleteEvent) + + await asyncio.sleep(0.01) + await tts.send("hello") + await tts.events.wait() + + # Expect start -> audio -> complete + assert TTSSynthesisStartEvent in events + assert TTSAudioEvent in events + assert TTSSynthesisCompleteEvent in events + assert len(audio_chunks) == 1 + # audio event sample_rate/channels reflect desired output + assert audio_chunks[0] is not None + + +@pytest.mark.asyncio +async def test_tts_bytes_async_aggregates_and_emits(): + tts = DummyTTSBytesAsync() + tts.set_output_format(sample_rate=16000, channels=1) + + chunks: List[bytes] = [] + + @tts.events.subscribe + async def _on_audio(ev: TTSAudioEvent): + if isinstance(ev, TTSAudioEvent) and ev.audio_data: + chunks.append(ev.audio_data) + + await asyncio.sleep(0.01) + await tts.send("hi") + await tts.events.wait() + + # Should emit at least one aligned chunk + assert len(chunks) >= 1 + # Sum of bytes equals or exceeds first unaligned chunk (due to padding/next chunk) + assert sum(len(c) for c in chunks) >= 2 * 33 # approx check + + +@pytest.mark.asyncio +async def test_tts_iter_sync_emits_multiple_chunks(): + tts = DummyTTSIterSync() + tts.set_output_format(sample_rate=16000, channels=1) + + chunks: List[bytes] = [] + + @tts.events.subscribe + async def _on_audio(ev: TTSAudioEvent): + if ev.audio_data: + chunks.append(ev.audio_data) + + await asyncio.sleep(0.01) + await tts.send("hello") + await tts.events.wait() + assert len(chunks) >= 2 + + +@pytest.mark.asyncio +async def test_tts_stereo_to_mono_halves_bytes(): + tts = DummyTTSPcmStereoToMono() + # desired mono, same sample rate + tts.set_output_format(sample_rate=16000, channels=1) + + emitted: List[bytes] = [] + + @tts.events.subscribe + async def _on_audio(ev: TTSAudioEvent): + if ev.audio_data: + emitted.append(ev.audio_data) + + await asyncio.sleep(0.01) + await tts.send("x") + await tts.events.wait() + assert len(emitted) == 1 + # Original interleaved data length was 400 bytes; mono should be ~200 bytes + assert 180 <= len(emitted[0]) <= 220 + + +@pytest.mark.asyncio +async def test_tts_resample_changes_size_reasonably(): + tts = DummyTTSPcmResample() + # Resample from 16k -> 8k, mono + tts.set_output_format(sample_rate=8000, channels=1) + + emitted: List[bytes] = [] + + @tts.events.subscribe + async def _on_audio(ev: TTSAudioEvent): + if ev.audio_data: + emitted.append(ev.audio_data) + + await asyncio.sleep(0.01) + await tts.send("y") + await tts.events.wait() + assert len(emitted) == 1 + # Input had 200 samples (400 bytes); at 8k this should be roughly half + assert 150 <= len(emitted[0]) <= 250 + + +@pytest.mark.asyncio +async def test_tts_error_emits_and_raises(): + tts = DummyTTSError() + + errors: List[TTSErrorEvent] = [] + + @tts.events.subscribe + async def _on_error(ev: TTSErrorEvent): + if isinstance(ev, TTSErrorEvent): + errors.append(ev) + + await asyncio.sleep(0.01) + with pytest.raises(RuntimeError): + await tts.send("boom") + await tts.events.wait() + assert len(errors) >= 1 From 512d874fb79855a41e6449082898594b622d80ee Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 11:31:13 +0200 Subject: 
[PATCH 02/15] stt ai instructions

---
 docs/ai/instructions/ai-tts.md | 72 +++++-----------------------------
 1 file changed, 9 insertions(+), 63 deletions(-)

diff --git a/docs/ai/instructions/ai-tts.md b/docs/ai/instructions/ai-tts.md
index 8cc7552e..5962e9cd 100644
--- a/docs/ai/instructions/ai-tts.md
+++ b/docs/ai/instructions/ai-tts.md
@@ -2,15 +2,15 @@
 
 Build a TTS plugin that streams audio and emits events. Keep it minimal and follow the project’s layout conventions.
 
-What to create (PEP 420 structure)
+## What to create
 
-- PEP 420: Do NOT add `__init__.py` in plugin folders. Use this layout:
+- Make sure to follow PEP 420: Do NOT add `__init__.py` in plugin folders. Use this layout:
   - `plugins//pyproject.toml` (depends on `vision-agents`)
   - `plugins//vision_agents/plugins//tts.py`
   - `plugins//tests/test_tts.py` (pytest tests at plugin root)
   - `plugins//example/` (optional, see `plugins/fish/example/fish_tts_example.py`)
 
-Implementation essentials
+## Implementation essentials
 
 - Inherit from `vision_agents.core.tts.tts.TTS`.
 - Implement `stream_audio(self, text, ...)` and return a single `PcmData`.
@@ -19,71 +19,17 @@
 
   from vision_agents.core.edge.types import PcmData
 
   async def stream_audio(self, text: str, *_, **__) -> PcmData:
-      # If your SDK returns raw bytes for the whole utterance
       audio_bytes = await my_sdk.tts.bytes(text=..., ...)
+      # sample_rate, channels and format depend on what the TTS model returns
       return PcmData.from_bytes(audio_bytes, sample_rate=16000, channels=1, format="s16")
   ```
 
-- `stop_audio` can be a no-op (the Agent controls playback):
+- `stop_audio` can be a no-op
 
-  ```python
-  async def stop_audio(self) -> None:
-      logger.info("TTS stop requested (no-op)")
-  ```
-
-Sample rate is important
-
-- Pass the provider’s native `sample_rate`, `channels`, and `format` to `PcmData.from_bytes`. The Agent resamples to its output track, but accurate native metadata is required for correct timing and quality.
-  - If your SDK is streaming, buffer the audio into a single byte string and return one `PcmData`.
 
-Testing and examples
+## Testing and examples
 
+- Look at `plugins/fish/tests/test_fish_tts.py` as a reference for what tests for a TTS plugin should look like
 - Add pytest tests at `plugins//tests/test_tts.py`. Keep them simple: assert that `stream_audio` yields `PcmData` and that `send()` emits `TTSAudioEvent`.
+- Do not write spec tests with mocks; this is usually not necessary
+- Make sure to write at least a couple of integration tests; use `TTSSession` to avoid boilerplate code in testing
 - Include a minimal example in `plugins//example/` (see `fish_tts_example.py`).
-
-Manual playback check (reusable)
-
-- Use the helper `vision_agents.core.tts.manual_test.manual_tts_to_wav` to generate a WAV and optionally play it with `ffplay`.
-- Example inside a plugin test:
-
-  ```python
-  import pytest
-  from vision_agents.core.tts.manual_test import manual_tts_to_wav
-  from vision_agents.plugins import fish
-
-  @pytest.mark.integration
-  async def test_manual_tts():
-      # Requires FISH_API_KEY or FISH_AUDIO_API_KEY
-      tts = fish.TTS()
-      path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1)
-      print("WAV written to:", path)
-  ```
-
-Environment variables
-
-- Provider API keys (plugin-specific). For Fish:
-  - `FISH_API_KEY` or `FISH_AUDIO_API_KEY` must be set.
-- Optional playback:
-  - Set `FFPLAY=1` and ensure `ffplay` is in PATH to auto-play the output WAV.
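For SDKs that stream their audio, the guidance above amounts to buffering before constructing `PcmData`; a minimal sketch, using the same hypothetical `my_sdk` placeholder as the doc's example:

```python
from vision_agents.core.edge.types import PcmData


async def stream_audio(self, text: str, *_, **__) -> PcmData:
    # Buffer the provider's streamed chunks into a single utterance.
    # my_sdk and its stream() call are placeholders, not a real SDK.
    buf = bytearray()
    async for chunk in my_sdk.tts.stream(text=text):
        buf.extend(chunk)
    # Report the provider's native sample rate, channels and format here.
    return PcmData.from_bytes(bytes(buf), sample_rate=16000, channels=1, format="s16")
```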
-
-Test session helper
-
-- To simplify event handling in tests, use `vision_agents.core.tts.testing.TTSSession`:
-
-  ```python
-  from vision_agents.core.tts.testing import TTSSession
-
-  tts = MyTTS(...)
-  tts.set_output_format(sample_rate=16000, channels=1)
-  session = TTSSession(tts)
-
-  await tts.send("Hello")
-  result = await session.wait_for_result(timeout=10.0)
-  assert not result.errors
-  assert result.speeches[0]
-  ```
-
-
-References
-
-- See existing plugins for patterns: `plugins/fish`, `plugins/cartesia`, `plugins/elevenlabs`, `plugins/kokoro`.

From e5e0cf542b7019e157daf86487ce5a9732a03fea Mon Sep 17 00:00:00 2001
From: Tommaso Barbugli
Date: Fri, 24 Oct 2025 12:58:18 +0200
Subject: [PATCH 03/15] openai tts plugin

---
 docs/ai/instructions/ai-tts.md          | 11 +++-
 plugins/cartesia/tests/test_tts.py      | 19 +++----
 plugins/elevenlabs/tests/test_tts.py    | 11 ++--
 plugins/fish/tests/test_fish_tts.py     |  9 +---
 plugins/openai/tests/test_tts_openai.py | 31 +++++++++++
 .../vision_agents/plugins/openai/__init__.py |  5 +-
 .../vision_agents/plugins/openai/tts.py |  51 +++++++++++++++++++
 7 files changed, 107 insertions(+), 30 deletions(-)
 create mode 100644 plugins/openai/tests/test_tts_openai.py
 create mode 100644 plugins/openai/vision_agents/plugins/openai/tts.py

diff --git a/docs/ai/instructions/ai-tts.md b/docs/ai/instructions/ai-tts.md
index 5962e9cd..f3376fbc 100644
--- a/docs/ai/instructions/ai-tts.md
+++ b/docs/ai/instructions/ai-tts.md
@@ -24,7 +24,16 @@ Build a TTS plugin that streams audio and emits events. Keep it minimal and foll
       return PcmData.from_bytes(audio_bytes, sample_rate=16000, channels=1, format="s16")
   ```
 
-- `stop_audio` can be a no-op
+- `stop_audio` can be a no-op
+
+## __init__
+
+The plugin constructor should:
+
+1. Rely on env vars to fetch credentials
+2. Expose kwargs that let developers pass important params to the model itself (e.g. model name, voice ID, API URL, ...)
+3. Accept a pre-configured model or client instance, if applicable
+4. Have defaults for all params when possible, so that setting the env var alone is enough
 
 ## Testing and examples
 
diff --git a/plugins/cartesia/tests/test_tts.py b/plugins/cartesia/tests/test_tts.py
index c4141325..a0c5e4f5 100644
--- a/plugins/cartesia/tests/test_tts.py
+++ b/plugins/cartesia/tests/test_tts.py
@@ -2,6 +2,7 @@
 import os
 
 import pytest
+import pytest_asyncio
 
 from vision_agents.plugins import cartesia
 from vision_agents.core.tts.manual_test import manual_tts_to_wav
@@ -12,6 +13,7 @@
 
 
 class TestCartesiaIntegration:
+    @pytest_asyncio.fixture
     def tts(self) -> cartesia.TTS:  # type: ignore[name-defined]
         api_key = os.environ.get("CARTESIA_API_KEY")
         if not api_key:
@@ -19,24 +21,15 @@ def tts(self) -> cartesia.TTS:  # type: ignore[name-defined]
         return cartesia.TTS(api_key=api_key)
 
     @pytest.mark.integration
-    @pytest.mark.asyncio
-    async def test_cartesia_with_real_api(self):
-        tts = self.tts()
+    async def test_cartesia_with_real_api(self, tts):
         tts.set_output_format(sample_rate=16000, channels=1)
         session = TTSSession(tts)
 
         await tts.send("Hello from Cartesia!")
+
         result = await session.wait_for_result(timeout=30)
 
         assert not result.errors
         assert len(result.speeches) > 0
 
     @pytest.mark.integration
-    @pytest.mark.asyncio
-    async def test_cartesia_tts_convert_text_to_audio_manual_test(self):
-        api_key = os.environ.get("CARTESIA_API_KEY")
-        if not api_key:
-            pytest.skip(
-                "CARTESIA_API_KEY env var not set – skipping manual playback test."
- ) - tts = self.tts() - path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1) - print("Cartesia TTS audio written to:", path) + async def test_cartesia_tts_convert_text_to_audio_manual_test(self, tts): + await manual_tts_to_wav(tts, sample_rate=16000, channels=1) diff --git a/plugins/elevenlabs/tests/test_tts.py b/plugins/elevenlabs/tests/test_tts.py index c1b349dc..46a3416c 100644 --- a/plugins/elevenlabs/tests/test_tts.py +++ b/plugins/elevenlabs/tests/test_tts.py @@ -1,5 +1,6 @@ import os import pytest +import pytest_asyncio from vision_agents.core.tts.testing import TTSSession from vision_agents.plugins import elevenlabs @@ -7,7 +8,7 @@ class TestElevenLabsIntegration: - @pytest.fixture + @pytest_asyncio.fixture def tts(self) -> elevenlabs.TTS: api_key = os.environ.get("ELEVENLABS_API_KEY") if not api_key: @@ -20,11 +21,9 @@ def tts(self) -> elevenlabs.TTS: async def test_elevenlabs_with_real_api(self, tts): tts.set_output_format(sample_rate=16000, channels=1) session = TTSSession(tts) - try: - await tts.send("This is a test of the ElevenLabs text-to-speech API.") - result = await session.wait_for_result(timeout=15.0) - except Exception as e: - pytest.skip(f"Unexpected error in ElevenLabs test: {e}") + + await tts.send("This is a test of the ElevenLabs text-to-speech API.") + result = await session.wait_for_result(timeout=15.0) assert not result.errors assert len(result.speeches) > 0 diff --git a/plugins/fish/tests/test_fish_tts.py b/plugins/fish/tests/test_fish_tts.py index 31e07b43..e23ff4c0 100644 --- a/plugins/fish/tests/test_fish_tts.py +++ b/plugins/fish/tests/test_fish_tts.py @@ -1,6 +1,5 @@ -import os - import pytest +import pytest_asyncio from dotenv import load_dotenv from vision_agents.plugins import fish @@ -12,16 +11,12 @@ class TestFishTTS: - @pytest.fixture + @pytest_asyncio.fixture def tts(self) -> fish.TTS: return fish.TTS() @pytest.mark.integration async def test_fish_tts_convert_text_to_audio_manual_test(self, tts: fish.TTS): - if not (os.environ.get("FISH_API_KEY") or os.environ.get("FISH_AUDIO_API_KEY")): - pytest.skip( - "FISH_API_KEY/FISH_AUDIO_API_KEY not set; skipping manual playback test." 
- ) await manual_tts_to_wav(tts, sample_rate=16000, channels=1) @pytest.mark.integration diff --git a/plugins/openai/tests/test_tts_openai.py b/plugins/openai/tests/test_tts_openai.py new file mode 100644 index 00000000..05c6746b --- /dev/null +++ b/plugins/openai/tests/test_tts_openai.py @@ -0,0 +1,31 @@ +import os +import pytest +import pytest_asyncio + +from vision_agents.plugins import openai as openai_plugin +from vision_agents.core.tts.testing import TTSSession +from vision_agents.core.tts.manual_test import manual_tts_to_wav + + +class TestOpenAITTSIntegration: + @pytest_asyncio.fixture + async def tts(self) -> openai_plugin.TTS: # type: ignore[name-defined] + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + pytest.skip("OPENAI_API_KEY not set") + return openai_plugin.TTS(api_key=api_key) + + @pytest.mark.integration + async def test_openai_tts_speech(self, tts: openai_plugin.TTS): + tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) + + await tts.send("Hello from OpenAI TTS") + + result = await session.wait_for_result(timeout=20.0) + assert not result.errors + assert len(result.speeches) > 0 + + @pytest.mark.integration + async def test_openai_tts_manual_wav(self, tts: openai_plugin.TTS): + await manual_tts_to_wav(tts, sample_rate=16000, channels=1) diff --git a/plugins/openai/vision_agents/plugins/openai/__init__.py b/plugins/openai/vision_agents/plugins/openai/__init__.py index be4ca2e4..cdf7ed65 100644 --- a/plugins/openai/vision_agents/plugins/openai/__init__.py +++ b/plugins/openai/vision_agents/plugins/openai/__init__.py @@ -1,6 +1,5 @@ - from .openai_llm import OpenAILLM as LLM from .openai_realtime import Realtime +from .tts import TTS -__all__ = ["Realtime", "LLM"] - +__all__ = ["Realtime", "LLM", "TTS"] diff --git a/plugins/openai/vision_agents/plugins/openai/tts.py b/plugins/openai/vision_agents/plugins/openai/tts.py new file mode 100644 index 00000000..d2bcd3f6 --- /dev/null +++ b/plugins/openai/vision_agents/plugins/openai/tts.py @@ -0,0 +1,51 @@ +import os +from typing import Optional + +from openai import AsyncOpenAI + +from vision_agents.core.tts.tts import TTS as BaseTTS +from vision_agents.core.edge.types import PcmData + + +class TTS(BaseTTS): + """OpenAI Text-to-Speech implementation. + + Uses OpenAI's TTS models to synthesize speech. + Docs: https://platform.openai.com/docs/guides/text-to-speech + """ + + def __init__( + self, + *, + api_key: Optional[str] = None, + model: str = "gpt-4o-mini-tts", + voice: str = "alloy", + client: Optional[AsyncOpenAI] = None, + ) -> None: + super().__init__(provider_name="openai_tts") + api_key = api_key or os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY not set") + self.client = client or AsyncOpenAI(api_key=api_key) + self.model = model + self.voice = voice + + async def stream_audio(self, text: str, *_, **__) -> PcmData: + """Synthesize the entire speech to a single PCM buffer. + + Base TTS handles resampling and event emission. 
+ """ + resp = await self.client.audio.speech.create( + model=self.model, + voice=self.voice, + input=text, + response_format="pcm", + ) + + return PcmData.from_bytes( + resp.content, sample_rate=24_000, channels=1, format="s16" + ) + + async def stop_audio(self) -> None: + # No internal playback queue; agent manages output track + return None From ff9ebed898bd695db32f183271c9843849142fe2 Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 13:33:07 +0200 Subject: [PATCH 04/15] AWS Polly with TTS support --- plugins/aws/README.md | 22 ++++- plugins/aws/example/aws_polly_tts_example.py | 16 ++++ plugins/aws/tests/test_tts.py | 48 ++++++++++ .../aws/vision_agents/plugins/aws/__init__.py | 3 +- plugins/aws/vision_agents/plugins/aws/tts.py | 92 +++++++++++++++++++ 5 files changed, 177 insertions(+), 4 deletions(-) create mode 100644 plugins/aws/example/aws_polly_tts_example.py create mode 100644 plugins/aws/tests/test_tts.py create mode 100644 plugins/aws/vision_agents/plugins/aws/tts.py diff --git a/plugins/aws/README.md b/plugins/aws/README.md index e82e21f4..ae1617ea 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -1,6 +1,6 @@ # AWS Plugin for Vision Agents -AWS (Bedrock) LLM integration for Vision Agents framework with support for both standard and realtime interactions. +AWS (Bedrock) LLM integration for Vision Agents framework with support for both standard and realtime interactions. Includes AWS Polly TTS. ## Installation @@ -32,7 +32,7 @@ The full example is available in example/aws_qwen_example.py Nova sonic audio realtime STS is also supported: -```python +```python agent = Agent( edge=getstream.Edge(), agent_user=User(name="Story Teller AI"), @@ -41,6 +41,21 @@ agent = Agent( ) ``` +### Polly TTS Usage + +```python +from vision_agents.plugins import aws +from vision_agents.core.tts.manual_test import manual_tts_to_wav +import asyncio + +async def main(): + # For PCM, AWS Polly supports 8000 or 16000 Hz + tts = aws.TTS(voice_id="Joanna", sample_rate=16000) + await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + +asyncio.run(main()) +``` + ## Running the examples Create a `.env` file, or cp .env.example to .env and fill in @@ -52,8 +67,9 @@ STREAM_API_SECRET=your_stream_api_secret_here AWS_BEARER_TOKEN_BEDROCK= AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= +AWS_REGION=us-east-1 FAL_KEY= CARTESIA_API_KEY= DEEPGRAM_API_KEY= -``` \ No newline at end of file +``` diff --git a/plugins/aws/example/aws_polly_tts_example.py b/plugins/aws/example/aws_polly_tts_example.py new file mode 100644 index 00000000..c93f1880 --- /dev/null +++ b/plugins/aws/example/aws_polly_tts_example.py @@ -0,0 +1,16 @@ +import asyncio +import os +from dotenv import load_dotenv + +from vision_agents.plugins.aws import TTS +from vision_agents.core.tts.manual_test import manual_tts_to_wav + + +async def main(): + load_dotenv() + tts = TTS(voice_id=os.environ.get("AWS_POLLY_VOICE", "Joanna")) + await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/plugins/aws/tests/test_tts.py b/plugins/aws/tests/test_tts.py new file mode 100644 index 00000000..9c9f1bbb --- /dev/null +++ b/plugins/aws/tests/test_tts.py @@ -0,0 +1,48 @@ +import os +import pytest +import pytest_asyncio +from dotenv import load_dotenv + +from vision_agents.plugins import aws as aws_plugin +from vision_agents.core.tts.testing import TTSSession +from vision_agents.core.tts.manual_test import manual_tts_to_wav + + +load_dotenv() + + +def 
+def _has_aws_creds() -> bool:
+    return any(
+        os.environ.get(k)
+        for k in (
+            "AWS_ACCESS_KEY_ID",
+            "AWS_SECRET_ACCESS_KEY",
+            "AWS_SESSION_TOKEN",
+            "AWS_PROFILE",
+            "AWS_WEB_IDENTITY_TOKEN_FILE",
+        )
+    )
+
+
+class TestAWSPollyTTS:
+    @pytest_asyncio.fixture
+    async def tts(self) -> aws_plugin.TTS:  # type: ignore[name-defined]
+        if not _has_aws_creds():
+            pytest.skip("AWS credentials not set – skipping Polly TTS tests")
+        # Region can be overridden via AWS_REGION/AWS_DEFAULT_REGION
+        return aws_plugin.TTS(voice_id=os.environ.get("AWS_POLLY_VOICE", "Joanna"))
+
+    @pytest.mark.integration
+    async def test_aws_polly_tts_speech(self, tts: aws_plugin.TTS):
+        tts.set_output_format(sample_rate=16000, channels=1)
+        session = TTSSession(tts)
+
+        await tts.send("Hello from AWS Polly TTS")
+
+        result = await session.wait_for_result(timeout=30.0)
+        assert not result.errors
+        assert len(result.speeches) > 0
+
+    @pytest.mark.integration
+    async def test_aws_polly_tts_manual_wav(self, tts: aws_plugin.TTS):
+        await manual_tts_to_wav(tts, sample_rate=16000, channels=1)
diff --git a/plugins/aws/vision_agents/plugins/aws/__init__.py b/plugins/aws/vision_agents/plugins/aws/__init__.py
index aec3c840..aafbcf67 100644
--- a/plugins/aws/vision_agents/plugins/aws/__init__.py
+++ b/plugins/aws/vision_agents/plugins/aws/__init__.py
@@ -1,4 +1,5 @@
 from .aws_llm import BedrockLLM as LLM
 from .aws_realtime import Realtime
+from .tts import TTS
 
-__all__ = ["LLM", "Realtime"]
+__all__ = ["LLM", "Realtime", "TTS"]
diff --git a/plugins/aws/vision_agents/plugins/aws/tts.py b/plugins/aws/vision_agents/plugins/aws/tts.py
new file mode 100644
index 00000000..21f697fd
--- /dev/null
+++ b/plugins/aws/vision_agents/plugins/aws/tts.py
@@ -0,0 +1,92 @@
+import os
+from typing import Optional, Union, Iterator, AsyncIterator, List, Any
+
+import boto3
+
+from vision_agents.core.tts.tts import TTS as BaseTTS
+from vision_agents.core.edge.types import PcmData
+
+
+class TTS(BaseTTS):
+    """AWS Polly Text-to-Speech implementation.
+
+    Follows AWS Polly SynthesizeSpeech API:
+    - OutputFormat is set to 'pcm' (signed 16-bit little-endian, mono)
+    - SampleRate must be one of {'8000','16000'} for PCM
+    - TextType can be 'text' or 'ssml' (defaults to 'text')
+    - Optional Engine ('standard' or 'neural'), LanguageCode, LexiconNames
+
+    Credentials are resolved via standard AWS SDK chain (env vars, profiles, roles).
+ """ + + def __init__( + self, + *, + region_name: Optional[str] = None, + voice_id: str = "Joanna", + text_type: Optional[str] = "text", # 'text' | 'ssml' + engine: Optional[str] = None, # 'standard' | 'neural' + language_code: Optional[str] = None, + lexicon_names: Optional[List[str]] = None, + client: Optional[Any] = None, + ) -> None: + super().__init__(provider_name="aws_polly") + self.region_name = ( + region_name + or os.environ.get("AWS_REGION") + or os.environ.get("AWS_DEFAULT_REGION") + or "us-east-1" + ) + self.voice_id = voice_id + + if engine is not None and engine not in ("standard", "neural"): + raise ValueError("engine must be 'standard' or 'neural' if provided") + if text_type is not None and text_type not in ("text", "ssml"): + raise ValueError("text_type must be 'text' or 'ssml' if provided") + + self.text_type = text_type + self.engine = engine + self.language_code = language_code + self.lexicon_names = lexicon_names + self._client = client + + @property + def client(self): + if self._client is None: + self._client = boto3.client("polly", region_name=self.region_name) + return self._client + + async def stream_audio( + self, text: str, *_, **__ + ) -> Union[PcmData, Iterator[PcmData], AsyncIterator[PcmData]]: + """Synthesize the entire speech to a single PCM buffer. + + Returns PcmData with s16 format and the configured sample rate. + """ + + params = { + "Text": text, + "OutputFormat": "pcm", + "VoiceId": self.voice_id, + "SampleRate": "16000", + "TextType": self.text_type, + } + + if self.engine is not None: + params["Engine"] = self.engine + if self.language_code is not None: + params["LanguageCode"] = self.language_code + if self.lexicon_names: + params["LexiconNames"] = self.lexicon_names # type: ignore[assignment] + + # Polly returns a StreamingBody for AudioStream + resp = self.client.synthesize_speech(**params) + + audio_bytes = resp["AudioStream"].read() + + return PcmData.from_bytes( + audio_bytes, sample_rate=16000, channels=1, format="s16" + ) + + async def stop_audio(self) -> None: + return None From 9d670f40b2eacd03c75fd1d42813c5b862b54e38 Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 13:49:34 +0200 Subject: [PATCH 05/15] cleanup code --- tests/test_tts_base.py | 110 ++++++++--------------------------------- 1 file changed, 21 insertions(+), 89 deletions(-) diff --git a/tests/test_tts_base.py b/tests/test_tts_base.py index 6cce0f87..07fea2ec 100644 --- a/tests/test_tts_base.py +++ b/tests/test_tts_base.py @@ -1,16 +1,10 @@ -import asyncio -from typing import AsyncIterator, Iterator, List +from typing import AsyncIterator, Iterator import pytest from vision_agents.core.tts.tts import TTS as BaseTTS -from vision_agents.core.tts.events import ( - TTSAudioEvent, - TTSErrorEvent, - TTSSynthesisStartEvent, - TTSSynthesisCompleteEvent, -) from vision_agents.core.edge.types import PcmData +from vision_agents.core.tts.testing import TTSSession class DummyTTSBytesSingle(BaseTTS): @@ -80,136 +74,74 @@ async def stop_audio(self) -> None: # pragma: no cover - noop return None -@pytest.mark.asyncio async def test_tts_bytes_single_emits_events_and_bytes(): tts = DummyTTSBytesSingle() tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) - events: List[type] = [] - audio_chunks: List[bytes] = [] - - @tts.events.subscribe - async def _on_start(ev: TTSSynthesisStartEvent): - events.append(TTSSynthesisStartEvent) - - @tts.events.subscribe - async def _on_audio(ev: TTSAudioEvent): - events.append(TTSAudioEvent) - if 
ev.audio_data: - audio_chunks.append(ev.audio_data) - - @tts.events.subscribe - async def _on_complete(ev: TTSSynthesisCompleteEvent): - events.append(TTSSynthesisCompleteEvent) - - await asyncio.sleep(0.01) await tts.send("hello") await tts.events.wait() + result = await session.wait_for_result(timeout=1.0) - # Expect start -> audio -> complete - assert TTSSynthesisStartEvent in events - assert TTSAudioEvent in events - assert TTSSynthesisCompleteEvent in events - assert len(audio_chunks) == 1 - # audio event sample_rate/channels reflect desired output - assert audio_chunks[0] is not None + assert result.started + assert result.completed + assert len(session.speeches) == 1 + assert session.speeches[0] is not None -@pytest.mark.asyncio async def test_tts_bytes_async_aggregates_and_emits(): tts = DummyTTSBytesAsync() tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) - chunks: List[bytes] = [] - - @tts.events.subscribe - async def _on_audio(ev: TTSAudioEvent): - if isinstance(ev, TTSAudioEvent) and ev.audio_data: - chunks.append(ev.audio_data) - - await asyncio.sleep(0.01) await tts.send("hi") await tts.events.wait() - # Should emit at least one aligned chunk - assert len(chunks) >= 1 - # Sum of bytes equals or exceeds first unaligned chunk (due to padding/next chunk) - assert sum(len(c) for c in chunks) >= 2 * 33 # approx check + assert len(session.speeches) >= 1 + assert sum(len(c) for c in session.speeches) >= 2 * 33 # approx check -@pytest.mark.asyncio async def test_tts_iter_sync_emits_multiple_chunks(): tts = DummyTTSIterSync() tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) - chunks: List[bytes] = [] - - @tts.events.subscribe - async def _on_audio(ev: TTSAudioEvent): - if ev.audio_data: - chunks.append(ev.audio_data) - - await asyncio.sleep(0.01) await tts.send("hello") await tts.events.wait() - assert len(chunks) >= 2 + assert len(session.speeches) >= 2 -@pytest.mark.asyncio async def test_tts_stereo_to_mono_halves_bytes(): tts = DummyTTSPcmStereoToMono() # desired mono, same sample rate tts.set_output_format(sample_rate=16000, channels=1) + session = TTSSession(tts) - emitted: List[bytes] = [] - - @tts.events.subscribe - async def _on_audio(ev: TTSAudioEvent): - if ev.audio_data: - emitted.append(ev.audio_data) - - await asyncio.sleep(0.01) await tts.send("x") await tts.events.wait() - assert len(emitted) == 1 + assert len(session.speeches) == 1 # Original interleaved data length was 400 bytes; mono should be ~200 bytes - assert 180 <= len(emitted[0]) <= 220 + assert 180 <= len(session.speeches[0]) <= 220 -@pytest.mark.asyncio async def test_tts_resample_changes_size_reasonably(): tts = DummyTTSPcmResample() # Resample from 16k -> 8k, mono tts.set_output_format(sample_rate=8000, channels=1) + session = TTSSession(tts) - emitted: List[bytes] = [] - - @tts.events.subscribe - async def _on_audio(ev: TTSAudioEvent): - if ev.audio_data: - emitted.append(ev.audio_data) - - await asyncio.sleep(0.01) await tts.send("y") await tts.events.wait() - assert len(emitted) == 1 + assert len(session.speeches) == 1 # Input had 200 samples (400 bytes); at 8k this should be roughly half - assert 150 <= len(emitted[0]) <= 250 + assert 150 <= len(session.speeches[0]) <= 250 -@pytest.mark.asyncio async def test_tts_error_emits_and_raises(): tts = DummyTTSError() + session = TTSSession(tts) - errors: List[TTSErrorEvent] = [] - - @tts.events.subscribe - async def _on_error(ev: TTSErrorEvent): - if isinstance(ev, TTSErrorEvent): - 
errors.append(ev) - - await asyncio.sleep(0.01) with pytest.raises(RuntimeError): await tts.send("boom") await tts.events.wait() - assert len(errors) >= 1 + assert len(session.errors) >= 1 From f497422461a3ebe4c46ee1c674f0880357d347ab Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 14:05:03 +0200 Subject: [PATCH 06/15] check for blocking send, fix AWS --- agents-core/vision_agents/core/tts/testing.py | 79 +++++++++++++++++++ agents-core/vision_agents/core/tts/tts.py | 9 --- docs/ai/instructions/ai-tests.md | 14 +++- docs/ai/instructions/ai-tts.md | 9 +++ plugins/aws/tests/test_tts.py | 6 +- plugins/aws/vision_agents/plugins/aws/tts.py | 9 ++- plugins/cartesia/tests/test_tts.py | 8 +- plugins/elevenlabs/tests/test_tts.py | 8 +- plugins/fish/tests/test_fish_tts.py | 8 +- plugins/openai/tests/test_tts_openai.py | 6 +- 10 files changed, 136 insertions(+), 20 deletions(-) diff --git a/agents-core/vision_agents/core/tts/testing.py b/agents-core/vision_agents/core/tts/testing.py index 1c291d0a..6153f515 100644 --- a/agents-core/vision_agents/core/tts/testing.py +++ b/agents-core/vision_agents/core/tts/testing.py @@ -79,3 +79,82 @@ async def wait_for_result(self, timeout: float = 10.0) -> TTSResult: started=self._started, completed=self._completed, ) + + +@dataclass +class EventLoopProbeResult: + ticks: int + elapsed_ms: float + max_gap_ms: float + + +async def _probe_event_loop_while(coro, interval: float = 0.01) -> EventLoopProbeResult: + """Run a coroutine while probing event loop responsiveness. + + Spawns a ticker task that sleeps for `interval` and counts ticks, + measuring the maximum observed gap between wakeups while `coro` runs. + + Returns probe metrics once `coro` completes. + """ + loop = asyncio.get_running_loop() + stop = asyncio.Event() + ticks = 0 + max_gap = 0.0 + last = loop.time() + + async def _ticker(): + nonlocal ticks, max_gap, last + while not stop.is_set(): + await asyncio.sleep(interval) + now = loop.time() + gap = (now - last) * 1000.0 + if gap > max_gap: + max_gap = gap + ticks += 1 + last = now + + start = loop.time() + ticker_task = asyncio.create_task(_ticker()) + try: + await coro + finally: + stop.set() + try: + await asyncio.wait_for(ticker_task, timeout=1.0) + except asyncio.TimeoutError: + ticker_task.cancel() + elapsed_ms = (loop.time() - start) * 1000.0 + return EventLoopProbeResult(ticks=ticks, elapsed_ms=elapsed_ms, max_gap_ms=max_gap) + + +async def assert_tts_send_non_blocking( + tts: TTS, + text: str = "Hello from non-blocking test", + *, + interval: float = 0.01, + min_observation_ms: float = 50.0, + min_expected_ticks: int = 2, +) -> EventLoopProbeResult: + """Assert that `tts.send(text)` does not block the event loop. + + This helper runs `tts.send(text)` while probing the event loop at a small + interval. If the call takes at least `min_observation_ms`, we require the + probe to tick at least `min_expected_ticks`. A zero or very low tick count + indicates the event loop was blocked (e.g., sync SDK call without executor). + + Returns the probe metrics for optional additional assertions. + """ + # Ensure output format is set so send() can emit properly even if unused + try: + tts.set_output_format(sample_rate=16000, channels=1) + except Exception: + pass + + probe = await _probe_event_loop_while(tts.send(text), interval=interval) + + if probe.elapsed_ms >= min_observation_ms: + assert probe.ticks >= min_expected_ticks, ( + f"tts.send blocked event loop: ticks={probe.ticks}, elapsed_ms={probe.elapsed_ms:.1f}. 
It looks like the stream_audio method is blocking the event loop." + ) + # If call was too fast, we don't strictly assert; return metrics for info + return probe diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index dc63dc50..9db3509c 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -14,7 +14,6 @@ TTSErrorEvent, ) from vision_agents.core.events import ( - PluginInitializedEvent, PluginClosedEvent, AudioFormat, ) @@ -69,14 +68,6 @@ def __init__(self, provider_name: Optional[str] = None): self._native_sample_rate: int = 16000 self._native_channels: int = 1 self._native_format: AudioFormat = AudioFormat.PCM_S16 - self.events.send( - PluginInitializedEvent( - session_id=self.session_id, - plugin_name=self.provider_name, - plugin_type="TTS", - provider=self.provider_name, - ) - ) def set_output_format( self, diff --git a/docs/ai/instructions/ai-tests.md b/docs/ai/instructions/ai-tests.md index c6080095..2d9dbbd2 100644 --- a/docs/ai/instructions/ai-tests.md +++ b/docs/ai/instructions/ai-tests.md @@ -6,4 +6,16 @@ This project uses uv to manage Python and its dependencies so when you run tests Extend from BaseTest -Store data for fixtures in tests/test_assets/... \ No newline at end of file +Store data for fixtures in tests/test_assets/... + +Non-blocking checks + +- TTS plugins must not block the event loop inside `stream_audio`. Use the helper in `vision_agents.core.tts.testing`: + + ```python + from vision_agents.core.tts.testing import assert_tts_send_non_blocking + + @pytest.mark.integration + async def test_tts_non_blocking(tts): + await assert_tts_send_non_blocking(tts, "Hello") + ``` diff --git a/docs/ai/instructions/ai-tts.md b/docs/ai/instructions/ai-tts.md index f3376fbc..3a8eff69 100644 --- a/docs/ai/instructions/ai-tts.md +++ b/docs/ai/instructions/ai-tts.md @@ -41,4 +41,13 @@ The plugin constructor should: - Add pytest tests at `plugins//tests/test_tts.py`. Keep them simple: assert that `stream_audio` yields `PcmData` and that `send()` emits `TTSAudioEvent`. - Do not write spec tests with mocks, this is usually not necessary - Make to write at least a couple integration tests, use `TTSSession` to avoid boiler-plate code in testing +- Verify your implementation does not block the event loop. Import and call `assert_tts_send_non_blocking`: + + ```python + from vision_agents.core.tts.testing import assert_tts_send_non_blocking + + @pytest.mark.integration + async def test_provider_non_blocking(tts): + await assert_tts_send_non_blocking(tts, "Hello from TTS") + ``` - Include a minimal example in `plugins//example/` (see `fish_tts_example.py`). 
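Putting the testing guidance above together, a plugin's test module stays small. A sketch for a hypothetical `myprovider` plugin (the plugin name and env var are assumptions; `TTSSession` and `assert_tts_send_non_blocking` are the helpers added in this series):

```python
import os

import pytest
import pytest_asyncio

from vision_agents.core.tts.testing import TTSSession, assert_tts_send_non_blocking
from vision_agents.plugins import myprovider  # hypothetical plugin


class TestMyProviderTTS:
    @pytest_asyncio.fixture
    async def tts(self) -> myprovider.TTS:
        if not os.environ.get("MYPROVIDER_API_KEY"):  # assumed env var name
            pytest.skip("MYPROVIDER_API_KEY not set")
        return myprovider.TTS()

    @pytest.mark.integration
    async def test_speech(self, tts):
        tts.set_output_format(sample_rate=16000, channels=1)
        session = TTSSession(tts)
        await tts.send("Hello from MyProvider")
        result = await session.wait_for_result(timeout=15.0)
        assert not result.errors
        assert len(result.speeches) > 0

    @pytest.mark.integration
    async def test_non_blocking(self, tts):
        await assert_tts_send_non_blocking(tts, "Hello from MyProvider")
```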
diff --git a/plugins/aws/tests/test_tts.py b/plugins/aws/tests/test_tts.py
index 9c9f1bbb..a881d02e 100644
--- a/plugins/aws/tests/test_tts.py
+++ b/plugins/aws/tests/test_tts.py
@@ -4,7 +4,7 @@
 from dotenv import load_dotenv
 
 from vision_agents.plugins import aws as aws_plugin
-from vision_agents.core.tts.testing import TTSSession
+from vision_agents.core.tts.testing import TTSSession, assert_tts_send_non_blocking
 from vision_agents.core.tts.manual_test import manual_tts_to_wav
 
 
@@ -46,3 +46,7 @@ async def test_aws_polly_tts_speech(self, tts: aws_plugin.TTS):
     @pytest.mark.integration
     async def test_aws_polly_tts_manual_wav(self, tts: aws_plugin.TTS):
         await manual_tts_to_wav(tts, sample_rate=16000, channels=1)
+
+    @pytest.mark.integration
+    async def test_aws_polly_tts_non_blocking(self, tts: aws_plugin.TTS):
+        await assert_tts_send_non_blocking(tts, "Hello from AWS Polly TTS")
diff --git a/plugins/aws/vision_agents/plugins/aws/tts.py b/plugins/aws/vision_agents/plugins/aws/tts.py
index 21f697fd..ea45b170 100644
--- a/plugins/aws/vision_agents/plugins/aws/tts.py
+++ b/plugins/aws/vision_agents/plugins/aws/tts.py
@@ -1,4 +1,6 @@
+import asyncio
 import os
+from functools import partial
 from typing import Optional, Union, Iterator, AsyncIterator, List, Any
 
 import boto3
@@ -79,8 +81,11 @@ async def stream_audio(
         if self.lexicon_names:
             params["LexiconNames"] = self.lexicon_names  # type: ignore[assignment]
 
-        # Polly returns a StreamingBody for AudioStream
-        resp = self.client.synthesize_speech(**params)
+        # synthesize_speech is a blocking boto3 call; run it in an executor so it does not block the event loop
+        loop = asyncio.get_running_loop()
+        resp = await loop.run_in_executor(
+            None, partial(self.client.synthesize_speech, **params)
+        )
 
         audio_bytes = resp["AudioStream"].read()
 
diff --git a/plugins/cartesia/tests/test_tts.py b/plugins/cartesia/tests/test_tts.py
index a0c5e4f5..b90f5bf3 100644
--- a/plugins/cartesia/tests/test_tts.py
+++ b/plugins/cartesia/tests/test_tts.py
@@ -6,7 +6,7 @@
 
 from vision_agents.plugins import cartesia
 from vision_agents.core.tts.manual_test import manual_tts_to_wav
-from vision_agents.core.tts.testing import TTSSession
+from vision_agents.core.tts.testing import TTSSession, assert_tts_send_non_blocking
 
 # Load environment variables
 load_dotenv()
@@ -14,7 +14,7 @@
 
 class TestCartesiaIntegration:
     @pytest_asyncio.fixture
-    def tts(self) -> cartesia.TTS:  # type: ignore[name-defined]
+    async def tts(self) -> cartesia.TTS:  # type: ignore[name-defined]
         api_key = os.environ.get("CARTESIA_API_KEY")
         if not api_key:
             pytest.skip("CARTESIA_API_KEY env var not set – skipping live API test.")
@@ -33,3 +33,7 @@ async def test_cartesia_with_real_api(self, tts):
     @pytest.mark.integration
     async def test_cartesia_tts_convert_text_to_audio_manual_test(self, tts):
         await manual_tts_to_wav(tts, sample_rate=16000, channels=1)
+
+    @pytest.mark.integration
+    async def test_cartesia_tts_non_blocking(self, tts: cartesia.TTS):
+        await assert_tts_send_non_blocking(tts, "Hello from Cartesia!")
diff --git a/plugins/elevenlabs/tests/test_tts.py b/plugins/elevenlabs/tests/test_tts.py
index 46a3416c..829e6be9 100644
--- a/plugins/elevenlabs/tests/test_tts.py
+++ b/plugins/elevenlabs/tests/test_tts.py
@@ -2,14 +2,14 @@
 
 import pytest
 import pytest_asyncio
-from vision_agents.core.tts.testing import TTSSession
+from vision_agents.core.tts.testing import TTSSession, assert_tts_send_non_blocking
 from vision_agents.plugins import elevenlabs
 from vision_agents.core.tts.manual_test import manual_tts_to_wav
 
 
class 
TestElevenLabsIntegration: @pytest_asyncio.fixture - def tts(self) -> elevenlabs.TTS: + async def tts(self) -> elevenlabs.TTS: api_key = os.environ.get("ELEVENLABS_API_KEY") if not api_key: pytest.skip( @@ -32,3 +32,7 @@ async def test_elevenlabs_with_real_api(self, tts): async def test_elevenlabs_tts_convert_text_to_audio_manual_test(self, tts): path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1) print("ElevenLabs TTS audio written to:", path) + + @pytest.mark.integration + async def test_elevenlabs_tts_non_blocking(self, tts): + await assert_tts_send_non_blocking(tts, "This should not block the event loop.") diff --git a/plugins/fish/tests/test_fish_tts.py b/plugins/fish/tests/test_fish_tts.py index e23ff4c0..a1450201 100644 --- a/plugins/fish/tests/test_fish_tts.py +++ b/plugins/fish/tests/test_fish_tts.py @@ -4,7 +4,7 @@ from vision_agents.plugins import fish from vision_agents.core.tts.manual_test import manual_tts_to_wav -from vision_agents.core.tts.testing import TTSSession +from vision_agents.core.tts.testing import TTSSession, assert_tts_send_non_blocking # Load environment variables load_dotenv() @@ -12,7 +12,7 @@ class TestFishTTS: @pytest_asyncio.fixture - def tts(self) -> fish.TTS: + async def tts(self) -> fish.TTS: return fish.TTS() @pytest.mark.integration @@ -30,3 +30,7 @@ async def test_fish_tts_convert_text_to_audio(self, tts: fish.TTS): assert not session.errors assert len(session.speeches) > 0 + + @pytest.mark.integration + async def test_fish_tts_non_blocking(self, tts: fish.TTS): + await assert_tts_send_non_blocking(tts, "Hello from Fish Audio.") diff --git a/plugins/openai/tests/test_tts_openai.py b/plugins/openai/tests/test_tts_openai.py index 05c6746b..d00ae4ba 100644 --- a/plugins/openai/tests/test_tts_openai.py +++ b/plugins/openai/tests/test_tts_openai.py @@ -3,7 +3,7 @@ import pytest_asyncio from vision_agents.plugins import openai as openai_plugin -from vision_agents.core.tts.testing import TTSSession +from vision_agents.core.tts.testing import TTSSession, assert_tts_send_non_blocking from vision_agents.core.tts.manual_test import manual_tts_to_wav @@ -29,3 +29,7 @@ async def test_openai_tts_speech(self, tts: openai_plugin.TTS): @pytest.mark.integration async def test_openai_tts_manual_wav(self, tts: openai_plugin.TTS): await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + + @pytest.mark.integration + async def test_openai_tts_non_blocking(self, tts: openai_plugin.TTS): + await assert_tts_send_non_blocking(tts, "Hello from OpenAI TTS") From 9e9366ad517ff04645093e0e38ed0f7e9bf99ba1 Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 14:23:54 +0200 Subject: [PATCH 07/15] small fixes --- .../vision_agents/core/agents/agents.py | 15 ++-- agents-core/vision_agents/core/tts/tts.py | 5 -- docs/ai/instructions/ai-tts.md | 4 +- .../simple_agent_example.py | 5 +- tests/test_tts_base.py | 76 ------------------- 5 files changed, 9 insertions(+), 96 deletions(-) diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py index 4094dd13..d358fab5 100644 --- a/agents-core/vision_agents/core/agents/agents.py +++ b/agents-core/vision_agents/core/agents/agents.py @@ -2,7 +2,7 @@ import logging import time import uuid -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from uuid import uuid4 import aiortc @@ -313,8 +313,6 @@ async def on_realtime_agent_speech_transcription( async def _on_tts_audio(event: 
TTSAudioEvent): try: if self._audio_track and event.audio_data: - from typing import Any, cast - track_any = cast(Any, self._audio_track) await track_any.write(event.audio_data) except Exception as e: @@ -1043,13 +1041,10 @@ def _prepare_rtc(self): # Inform TTS of desired output format so it can resample accordingly if self.tts: channels = 2 if stereo else 1 - try: - self.tts.set_output_format( - sample_rate=framerate, - channels=channels, - ) - except Exception as e: - self.logger.warning(f"Failed to set TTS output format: {e}") + self.tts.set_output_format( + sample_rate=framerate, + channels=channels, + ) # Set up video track if video publishers are available if self.publish_video: diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index 9db3509c..0b8f7340 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -89,11 +89,6 @@ def set_output_format( self._desired_channels = int(channels) self._desired_format = audio_format - # Backwards-compatibility helper if any subclass still calls it - def set_native_format(self, sample_rate: int, channels: int = 1) -> None: - self._native_sample_rate = int(sample_rate) - self._native_channels = int(channels) - def _normalize_to_pcm(self, item: Union[bytes, bytearray, PcmData, Any]) -> PcmData: """Normalize a chunk to PcmData using the native provider format.""" if isinstance(item, PcmData): diff --git a/docs/ai/instructions/ai-tts.md b/docs/ai/instructions/ai-tts.md index 3a8eff69..78a768f8 100644 --- a/docs/ai/instructions/ai-tts.md +++ b/docs/ai/instructions/ai-tts.md @@ -20,7 +20,7 @@ Build a TTS plugin that streams audio and emits events. Keep it minimal and foll async def stream_audio(self, text: str, *_, **__) -> PcmData: audio_bytes = await my_sdk.tts.bytes(text=..., ...) - # sample_rate, channels and format depend on what the STT model returns + # sample_rate, channels and format depend on what the TTS model returns return PcmData.from_bytes(audio_bytes, sample_rate=16000, channels=1, format="s16") ``` @@ -40,7 +40,7 @@ The plugin constructor should: - Look at `plugins/fish/tests/test_fish_tts.py` as a reference of what tests for a TTS plugins should look like - Add pytest tests at `plugins//tests/test_tts.py`. Keep them simple: assert that `stream_audio` yields `PcmData` and that `send()` emits `TTSAudioEvent`. - Do not write spec tests with mocks, this is usually not necessary -- Make to write at least a couple integration tests, use `TTSSession` to avoid boiler-plate code in testing +- Make sure to write at least a couple of integration tests, use `TTSSession` to avoid boiler-plate code in testing - Verify your implementation does not block the event loop. 
Import and call `assert_tts_send_non_blocking`: ```python diff --git a/examples/01_simple_agent_example/simple_agent_example.py b/examples/01_simple_agent_example/simple_agent_example.py index 69dffc74..7562751b 100644 --- a/examples/01_simple_agent_example/simple_agent_example.py +++ b/examples/01_simple_agent_example/simple_agent_example.py @@ -48,16 +48,15 @@ async def start_agent() -> None: # } # ],) - # run till the call ends # await agent.say("Hello, how are you?") # await asyncio.sleep(5) - # Open the demo UI - await agent.edge.open_demo(call) # Open the demo UI await agent.edge.open_demo(call) await agent.simple_response("tell me something interesting in a short sentence") + + # run till the call ends await agent.finish() diff --git a/tests/test_tts_base.py b/tests/test_tts_base.py index 07fea2ec..477c214e 100644 --- a/tests/test_tts_base.py +++ b/tests/test_tts_base.py @@ -1,5 +1,3 @@ -from typing import AsyncIterator, Iterator - import pytest from vision_agents.core.tts.tts import TTS as BaseTTS @@ -7,43 +5,6 @@ from vision_agents.core.tts.testing import TTSSession -class DummyTTSBytesSingle(BaseTTS): - async def stream_audio(self, text: str, *_, **__) -> bytes: - # 16-bit PCM mono (s16), 100 samples -> 200 bytes - self._native_sample_rate = 16000 - self._native_channels = 1 - return b"\x00\x00" * 100 - - async def stop_audio(self) -> None: # pragma: no cover - noop - return None - - -class DummyTTSBytesAsync(BaseTTS): - async def stream_audio(self, text: str, *_, **__) -> AsyncIterator[bytes]: - self._native_sample_rate = 16000 - self._native_channels = 1 - - async def _agen(): - # Unaligned chunk sizes to test aggregator - yield b"\x00\x00" * 33 + b"\x00" # odd size - yield b"\x00\x00" * 10 - - return _agen() - - async def stop_audio(self) -> None: # pragma: no cover - noop - return None - - -class DummyTTSIterSync(BaseTTS): - async def stream_audio(self, text: str, *_, **__) -> Iterator[bytes]: - self._native_sample_rate = 16000 - self._native_channels = 1 - return iter([b"\x00\x00" * 50, b"\x00\x00" * 25]) - - async def stop_audio(self) -> None: # pragma: no cover - noop - return None - - class DummyTTSPcmStereoToMono(BaseTTS): async def stream_audio(self, text: str, *_, **__) -> PcmData: # 2 channels interleaved: 100 frames (per channel) -> 200 samples -> 400 bytes @@ -74,43 +35,6 @@ async def stop_audio(self) -> None: # pragma: no cover - noop return None -async def test_tts_bytes_single_emits_events_and_bytes(): - tts = DummyTTSBytesSingle() - tts.set_output_format(sample_rate=16000, channels=1) - session = TTSSession(tts) - - await tts.send("hello") - await tts.events.wait() - result = await session.wait_for_result(timeout=1.0) - - assert result.started - assert result.completed - assert len(session.speeches) == 1 - assert session.speeches[0] is not None - - -async def test_tts_bytes_async_aggregates_and_emits(): - tts = DummyTTSBytesAsync() - tts.set_output_format(sample_rate=16000, channels=1) - session = TTSSession(tts) - - await tts.send("hi") - await tts.events.wait() - - assert len(session.speeches) >= 1 - assert sum(len(c) for c in session.speeches) >= 2 * 33 # approx check - - -async def test_tts_iter_sync_emits_multiple_chunks(): - tts = DummyTTSIterSync() - tts.set_output_format(sample_rate=16000, channels=1) - session = TTSSession(tts) - - await tts.send("hello") - await tts.events.wait() - assert len(session.speeches) >= 2 - - async def test_tts_stereo_to_mono_halves_bytes(): tts = DummyTTSPcmStereoToMono() # desired mono, same sample rate From 
b7c57e9731d00382fb9a47af444d398cdd186acd Mon Sep 17 00:00:00 2001
From: Tommaso Barbugli
Date: Fri, 24 Oct 2025 15:02:50 +0200
Subject: [PATCH 08/15] properly type the output track

---
 .../vision_agents/core/agents/agents.py | 18 +++++----------
 .../vision_agents/core/edge/edge_transport.py | 11 +++++----
 agents-core/vision_agents/core/edge/types.py | 23 ++++++++++++++++++-
 .../getstream/stream_edge_transport.py | 8 ++++---
 4 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py
index d358fab5..9f6bee1f 100644
--- a/agents-core/vision_agents/core/agents/agents.py
+++ b/agents-core/vision_agents/core/agents/agents.py
@@ -2,10 +2,9 @@
 import logging
 import time
 import uuid
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 from uuid import uuid4
 
-import aiortc
 import getstream.models
 from aiortc import VideoStreamTrack
 from getstream.video.rtc import Call
@@ -15,7 +14,7 @@
 from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
 from ..edge import sfu_events
 from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
-from ..edge.types import Connection, Participant, PcmData, User
+from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack
 from ..events.manager import EventManager
 from ..llm import events as llm_events
 from ..llm.events import (
@@ -161,7 +160,7 @@ def __init__(
         self._callback_executed = False
         self._track_tasks: Dict[str, asyncio.Task] = {}
         self._connection: Optional[Connection] = None
-        self._audio_track: Optional[aiortc.AudioStreamTrack] = None
+        self._audio_track: Optional[OutputAudioTrack] = None
         self._video_track: Optional[VideoStreamTrack] = None
         self._realtime_connection = None
         self._pc_track_handler_attached: bool = False
@@ -308,15 +307,10 @@ async def on_realtime_agent_speech_transcription(
                 original=event,
             )
 
-        # Listen for TTS audio events and write audio to the output track
         @self.events.subscribe
-        async def _on_tts_audio(event: TTSAudioEvent):
-            try:
-                if self._audio_track and event.audio_data:
-                    track_any = cast(Any, self._audio_track)
-                    await track_any.write(event.audio_data)
-            except Exception as e:
-                self.logger.error(f"Error writing TTS audio to track: {e}")
+        async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
+            if self._audio_track and event and event.audio_data is not None:
+                await self._audio_track.write(event.audio_data)
 
         @self.events.subscribe
         async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
diff --git a/agents-core/vision_agents/core/edge/edge_transport.py b/agents-core/vision_agents/core/edge/edge_transport.py
index 073728cf..89cb9fc6 100644
--- a/agents-core/vision_agents/core/edge/edge_transport.py
+++ b/agents-core/vision_agents/core/edge/edge_transport.py
@@ -1,6 +1,7 @@
 """
 Abstraction for stream vs other services here
 """
+
 import abc
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -8,10 +9,8 @@
 import aiortc
 from pyee.asyncio import AsyncIOEventEmitter
 
-from vision_agents.core.edge.types import User
+from vision_agents.core.edge.types import User, OutputAudioTrack
 
-if TYPE_CHECKING:
-    pass
 
 
@@ -31,7 +30,7 @@ async def create_user(self, user: User):
     pass
 
     @abc.abstractmethod
-    def create_audio_track(self):
+    def create_audio_track(self) -> OutputAudioTrack:
         pass
 
     @abc.abstractmethod
@@ -55,6 +54,7 @@ async def create_conversation(self, call: Any, user: User, 
instructions):
         pass
 
     @abc.abstractmethod
-    def add_track_subscriber(self, track_id: str) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
+    def add_track_subscriber(
+        self, track_id: str
+    ) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
         pass
-
diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py
index fb5c1ba1..cd672482 100644
--- a/agents-core/vision_agents/core/edge/types.py
+++ b/agents-core/vision_agents/core/edge/types.py
@@ -1,5 +1,14 @@
 from dataclasses import dataclass
-from typing import Any, Optional, NamedTuple, Union, Iterator, AsyncIterator
+from typing import (
+    Any,
+    Optional,
+    NamedTuple,
+    Union,
+    Iterator,
+    AsyncIterator,
+    Protocol,
+    runtime_checkable,
+)
 import logging
 
 import numpy as np
@@ -34,6 +43,18 @@ async def close(self):
         pass
 
 
+@runtime_checkable
+class OutputAudioTrack(Protocol):
+    """
+    A protocol describing an output audio track; the actual implementation depends on the edge transport used,
+    e.g. getstream.video.rtc.audio_track.AudioStreamTrack
+    """
+
+    async def write(self, data: bytes) -> None: ...
+
+    def stop(self) -> None: ...
+
+
 class PcmData(NamedTuple):
     """
     A named tuple representing PCM audio data.
diff --git a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py
index de87850b..67131e7a 100644
--- a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py
+++ b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py
@@ -22,7 +22,7 @@
 from vision_agents.core.edge import EdgeTransport, sfu_events
 from vision_agents.plugins.getstream.stream_conversation import StreamConversation
-from vision_agents.core.edge.types import Connection, User
+from vision_agents.core.edge.types import Connection, User, OutputAudioTrack
 from vision_agents.core.events.manager import EventManager
 from vision_agents.core.edge import events
 from vision_agents.core.utils import get_vision_agents_version
@@ -104,7 +104,7 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
             track_type_int = event.payload.type  # TrackType enum int from SFU
             expected_kind = self._get_webrtc_kind(track_type_int)
             track_key = (user_id, session_id, track_type_int)
-            is_agent_track = (user_id == self.agent_user_id)
+            is_agent_track = user_id == self.agent_user_id
 
             # First check if track already exists in map (e.g., from previous unpublish/republish)
             if track_key in self._track_map:
@@ -288,7 +288,9 @@ async def on_audio_received(pcm: PcmData, participant: Participant):
         standardize_connection = StreamConnection(connection)
         return standardize_connection
 
-    def create_audio_track(self, framerate: int = 48000, stereo: bool = True):
+    def create_audio_track(
+        self, framerate: int = 48000, stereo: bool = True
+    ) -> OutputAudioTrack:
         return audio_track.AudioStreamTrack(
             framerate=framerate, stereo=stereo
         )  # default to webrtc framerate

From 4e93b120baea5344918f468d42bc0ac41bc39a3a Mon Sep 17 00:00:00 2001
From: Tommaso Barbugli
Date: Fri, 24 Oct 2025 21:35:35 +0200
Subject: [PATCH 09/15] working resampling mechanism

---
 agents-core/vision_agents/core/edge/types.py | 170 +++++++++---
 agents-core/vision_agents/core/tts/tts.py | 138 +++++++++-
 conftest.py | 63 ++---
 plugins/aws/README.md | 16 +-
 plugins/aws/example/aws_polly_tts_example.py | 2 +-
 plugins/aws/tests/test_tts.py | 2 +-
 plugins/cartesia/tests/test_tts.py | 2 +-
 plugins/elevenlabs/tests/test_tts.py | 2 +-
plugins/fish/tests/test_fish_tts.py | 2 +- plugins/kokoro/tests/test_tts.py | 2 +- plugins/openai/tests/test_tts_openai.py | 2 +- tests/test_pcm_data.py | 267 +++++++++++++++++++ tests/test_resample_quality.py | 146 ++++++++++ 13 files changed, 718 insertions(+), 96 deletions(-) create mode 100644 tests/test_pcm_data.py create mode 100644 tests/test_resample_quality.py diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py index cd672482..24b2fb6d 100644 --- a/agents-core/vision_agents/core/edge/types.py +++ b/agents-core/vision_agents/core/edge/types.py @@ -89,19 +89,36 @@ def duration(self) -> float: # For f32 format, each element in the array is one sample (float32) if isinstance(self.samples, np.ndarray): - # If array has shape (channels, samples), duration uses the samples dimension + # If array has shape (channels, samples) or (samples, channels), duration uses the samples dimension if self.samples.ndim == 2: - num_samples = self.samples.shape[-1] + # Determine which dimension is samples vs channels + # Standard format is (channels, samples), but we need to handle both + ch = self.channels if self.channels else 1 + if self.samples.shape[0] == ch: + # Shape is (channels, samples) - correct format + num_samples = self.samples.shape[1] + elif self.samples.shape[1] == ch: + # Shape is (samples, channels) - transposed format + num_samples = self.samples.shape[0] + else: + # Ambiguous or unknown - assume (channels, samples) and pick larger dimension + # This handles edge cases like (2, 2) arrays + num_samples = max(self.samples.shape[0], self.samples.shape[1]) else: num_samples = len(self.samples) elif isinstance(self.samples, bytes): # If samples is bytes, calculate based on format if self.format == "s16": # For s16 format, each sample is 2 bytes (16 bits) - num_samples = len(self.samples) // 2 + # For multi-channel, divide by channels to get sample count + num_samples = len(self.samples) // ( + 2 * (self.channels if self.channels else 1) + ) elif self.format == "f32": # For f32 format, each sample is 4 bytes (32 bits) - num_samples = len(self.samples) // 4 + num_samples = len(self.samples) // ( + 4 * (self.channels if self.channels else 1) + ) else: # Default assumption for other formats (treat as raw bytes) num_samples = len(self.samples) @@ -287,25 +304,39 @@ def resample( if self.sample_rate == target_sample_rate and target_channels == self.channels: return self - # Prepare ndarray shape for AV. - # Our convention: (channels, samples) for multi-channel, (samples,) for mono. - samples = self.samples - if samples.ndim == 1: - # Mono: reshape to (1, samples) for AV - samples = samples.reshape(1, -1) - elif samples.ndim == 2: - # Already (channels, samples) - pass - - # Create AV audio frame from the samples + # Prepare ndarray shape for AV input frame. + # Use planar input (s16p) with shape (channels, samples). 
in_layout = "mono" if self.channels == 1 else "stereo" - # For multi-channel, use planar format to avoid packed shape errors - in_format = "s16" if self.channels == 1 else "s16p" - samples = np.ascontiguousarray(samples) - frame = av.AudioFrame.from_ndarray(samples, format=in_format, layout=in_layout) + cmaj = self.samples + if isinstance(cmaj, np.ndarray): + if cmaj.ndim == 1: + # (samples,) -> (channels, samples) + if self.channels > 1: + cmaj = np.tile(cmaj, (self.channels, 1)) + else: + cmaj = cmaj.reshape(1, -1) + elif cmaj.ndim == 2: + # Normalize to (channels, samples) + ch = self.channels if self.channels else 1 + if cmaj.shape[0] == ch: + # Already (channels, samples) + pass + elif cmaj.shape[1] == ch: + # (samples, channels) -> transpose + cmaj = cmaj.T + else: + # Ambiguous - assume larger dim is samples + if cmaj.shape[1] > cmaj.shape[0]: + # Likely (channels, samples) + pass + else: + # Likely (samples, channels) + cmaj = cmaj.T + cmaj = np.ascontiguousarray(cmaj) + frame = av.AudioFrame.from_ndarray(cmaj, format="s16p", layout=in_layout) frame.sample_rate = self.sample_rate - # Create resampler + # Create resampler – output packed s16 out_layout = "mono" if target_channels == 1 else "stereo" resampler = av.AudioResampler( format="s16", layout=out_layout, rate=target_sample_rate @@ -315,20 +346,72 @@ def resample( resampled_frames = resampler.resample(frame) if resampled_frames: resampled_frame = resampled_frames[0] - resampled_samples = resampled_frame.to_ndarray() - - # AV returns (channels, samples), so for mono we want the first (and only) channel - if len(resampled_samples.shape) > 1: - if target_channels == 1: - resampled_samples = resampled_samples[0] + # PyAV's to_ndarray() for packed format returns flattened interleaved data + # For stereo s16 (packed), it returns shape (1, num_values) where num_values = samples * channels + raw_array = resampled_frame.to_ndarray() + num_frames = resampled_frame.samples # Actual number of sample frames + + # Normalize output to (channels, samples) format + ch = int(target_channels) + + # Handle PyAV's packed format quirk: returns (1, num_values) for stereo + if raw_array.ndim == 2 and raw_array.shape[0] == 1 and ch > 1: + # Flatten and deinterleave packed stereo data + # Shape (1, 32000) -> (32000,) -> deinterleave to (2, 16000) + flat = raw_array.reshape(-1) + if len(flat) == num_frames * ch: + # Deinterleave: [L0,R0,L1,R1,...] 
-> [[L0,L1,...], [R0,R1,...]] + resampled_samples = flat.reshape(-1, ch).T + else: + logger.warning( + "Unexpected array size %d for %d frames x %d channels", + len(flat), + num_frames, + ch, + ) + resampled_samples = flat.reshape(ch, -1) + elif raw_array.ndim == 2: + # Standard case: (samples, channels) or already (channels, samples) + if raw_array.shape[1] == ch: + # (samples, channels) -> transpose to (channels, samples) + resampled_samples = raw_array.T + elif raw_array.shape[0] == ch: + # Already (channels, samples) + resampled_samples = raw_array + else: + # Ambiguous - assume time-major + resampled_samples = raw_array.T + elif raw_array.ndim == 1: + # 1D output (mono) + if ch == 1: + # Keep as 1D for mono + resampled_samples = raw_array + elif ch > 1: + # Shouldn't happen if we requested stereo, but handle it + logger.warning( + "Got 1D array but requested %d channels, duplicating", ch + ) + resampled_samples = np.tile(raw_array, (ch, 1)) + else: + resampled_samples = raw_array + else: + # Unexpected dimensionality + logger.warning( + "Unexpected ndim %d from PyAV, reshaping", raw_array.ndim + ) + resampled_samples = raw_array.reshape(ch, -1) - # Convert to int16 - resampled_samples = resampled_samples.astype(np.int16) + # Ensure int16 dtype for s16 + if ( + isinstance(resampled_samples, np.ndarray) + and resampled_samples.dtype != np.int16 + ): + resampled_samples = resampled_samples.astype(np.int16) return PcmData( samples=resampled_samples, sample_rate=target_sample_rate, - format=self.format, + format="s16", pts=self.pts, dts=self.dts, time_base=self.time_base, @@ -339,13 +422,34 @@ def resample( return self def to_bytes(self) -> bytes: - """Return interleaved PCM bytes (s16 or f32 depending on format).""" + """Return interleaved PCM bytes (s16 or f32 depending on format). + + For multi-channel audio, this returns packed/interleaved bytes in the order + [L0, R0, L1, R1, ...]. The internal convention is (channels, samples). + If the stored ndarray is (samples, channels), we transpose it. 
+ """ arr = self.samples if isinstance(arr, np.ndarray): if arr.ndim == 2: - # (channels, samples) -> interleaved (samples, channels) - interleaved = arr.T.reshape(-1) - return interleaved.tobytes() + channels = int(self.channels or arr.shape[0]) + # Normalize to (channels, samples) + if arr.shape[0] == channels: + cmaj = arr + elif arr.shape[1] == channels: + cmaj = arr.T + else: + logger.warning( + "to_bytes: ambiguous array shape %s for channels=%d; assuming time-major", + arr.shape, + channels, + ) + cmaj = arr.T + samples_count = cmaj.shape[1] + # Interleave channels explicitly to avoid any stride-related surprises + out = np.empty(samples_count * channels, dtype=cmaj.dtype) + for i in range(channels): + out[i::channels] = cmaj[i] + return out.tobytes() return arr.tobytes() # Fallback if isinstance(arr, (bytes, bytearray)): diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index 0b8f7340..d5677e46 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -1,4 +1,5 @@ import abc +import av import logging import time import uuid @@ -24,6 +25,7 @@ tts_events_emitted, ) from ..edge.types import PcmData +import numpy as np logger = logging.getLogger(__name__) @@ -68,6 +70,10 @@ def __init__(self, provider_name: Optional[str] = None): self._native_sample_rate: int = 16000 self._native_channels: int = 1 self._native_format: AudioFormat = AudioFormat.PCM_S16 + # Persistent resampler to avoid discontinuities between chunks + self._resampler = None + self._resampler_input_rate: Optional[int] = None + self._resampler_input_channels: Optional[int] = None def set_output_format( self, @@ -89,6 +95,46 @@ def set_output_format( self._desired_channels = int(channels) self._desired_format = audio_format + self._resampler = None + self._resampler_input_rate = None + self._resampler_input_channels = None + + def _get_resampler(self, input_rate: int, input_channels: int): + """Get or create a persistent resampler for the given input format. + + This avoids creating a new resampler for each chunk, which causes + discontinuities and clicking artifacts in the output audio. 
+ + Args: + input_rate: Input sample rate + input_channels: Input channel count + + Returns: + PyAV AudioResampler instance + """ + + if self._resampler is not None and self._resampler_input_rate == input_rate and self._resampler_input_channels == input_channels: + return self._resampler + + in_layout = "mono" if input_channels == 1 else "stereo" + out_layout = "mono" if self._desired_channels == 1 else "stereo" + + self._resampler = av.AudioResampler( + format="s16", layout=out_layout, rate=self._desired_sample_rate + ) + self._resampler_input_rate = input_rate + self._resampler_input_channels = input_channels + + logger.debug( + "Created persistent resampler: %s@%dHz -> %s@%dHz", + in_layout, + input_rate, + out_layout, + self._desired_sample_rate, + ) + + return self._resampler + def _normalize_to_pcm(self, item: Union[bytes, bytearray, PcmData, Any]) -> PcmData: """Normalize a chunk to PcmData using the native provider format.""" if isinstance(item, PcmData): @@ -136,7 +182,92 @@ def _emit_chunk( user: Optional[Dict[str, Any]], ) -> tuple[int, float]: """Resample, serialize, emit TTSAudioEvent; return (bytes_len, duration_ms).""" - pcm_out = pcm.resample(self._desired_sample_rate, self._desired_channels) + + if ( + pcm.sample_rate == self._desired_sample_rate + and pcm.channels == self._desired_channels + ): + # No resampling needed + pcm_out = pcm + else: + resampler = self._get_resampler(pcm.sample_rate, pcm.channels) + + # Prepare input frame in planar format + samples = pcm.samples + if isinstance(samples, np.ndarray): + if samples.ndim == 1: + if pcm.channels > 1: + cmaj = np.tile(samples, (pcm.channels, 1)) + else: + cmaj = samples.reshape(1, -1) + elif samples.ndim == 2: + ch = pcm.channels if pcm.channels else 1 + if samples.shape[0] == ch: + cmaj = samples + elif samples.shape[1] == ch: + cmaj = samples.T + else: + if samples.shape[1] > samples.shape[0]: + cmaj = samples + else: + cmaj = samples.T + cmaj = np.ascontiguousarray(cmaj) + else: + # Shouldn't happen, but handle it + cmaj = ( + samples.reshape(1, -1) + if isinstance(samples, np.ndarray) + else samples + ) + + in_layout = "mono" if pcm.channels == 1 else "stereo" + frame = av.AudioFrame.from_ndarray(cmaj, format="s16p", layout=in_layout) + frame.sample_rate = pcm.sample_rate + + # Resample using persistent resampler + resampled_frames = resampler.resample(frame) + + if resampled_frames: + resampled_frame = resampled_frames[0] + raw_array = resampled_frame.to_ndarray() + num_frames = resampled_frame.samples + + # Handle PyAV's packed format quirk + ch = self._desired_channels + if raw_array.ndim == 2 and raw_array.shape[0] == 1 and ch > 1: + flat = raw_array.reshape(-1) + if len(flat) == num_frames * ch: + resampled_samples = flat.reshape(-1, ch).T + else: + resampled_samples = flat.reshape(ch, -1) + elif raw_array.ndim == 2: + if raw_array.shape[1] == ch: + resampled_samples = raw_array.T + elif raw_array.shape[0] == ch: + resampled_samples = raw_array + else: + resampled_samples = raw_array.T + elif raw_array.ndim == 1: + if ch == 1: + resampled_samples = raw_array + else: + resampled_samples = np.tile(raw_array, (ch, 1)) + else: + resampled_samples = raw_array.reshape(ch, -1) + + if resampled_samples.dtype != np.int16: + resampled_samples = resampled_samples.astype(np.int16) + + pcm_out = PcmData( + samples=resampled_samples, + sample_rate=self._desired_sample_rate, + format="s16", + channels=self._desired_channels, + ) + else: + # Resampling failed, use original + pcm_out = pcm + payload = pcm_out.to_bytes() # 
Metrics: counters per chunk attrs = {"tts_class": self.__class__.__name__} @@ -215,6 +346,11 @@ async def send( start_time = time.time() synthesis_id = str(uuid.uuid4()) + # Reset resampler for each new synthesis to ensure clean state + self._resampler = None + self._resampler_input_rate = None + self._resampler_input_channels = None + logger.debug( "Starting text-to-speech synthesis", extra={"text_length": len(text)} ) diff --git a/conftest.py b/conftest.py index da17b102..034e4e7a 100644 --- a/conftest.py +++ b/conftest.py @@ -21,14 +21,14 @@ class STTSession: """Helper class for testing STT implementations. - + Automatically subscribes to transcript and error events, collects them, and provides a convenient wait method. """ - + def __init__(self, stt): """Initialize STT session with an STT object. - + Args: stt: STT implementation to monitor """ @@ -36,39 +36,39 @@ def __init__(self, stt): self.transcripts = [] self.errors = [] self._event = asyncio.Event() - + # Subscribe to events @stt.events.subscribe async def on_transcript(event: STTTranscriptEvent): self.transcripts.append(event) self._event.set() - + @stt.events.subscribe async def on_error(event: STTErrorEvent): self.errors.append(event.error) self._event.set() - + self._on_transcript = on_transcript self._on_error = on_error - + async def wait_for_result(self, timeout: float = 30.0): """Wait for either a transcript or error event. - + Args: timeout: Maximum time to wait in seconds - + Raises: asyncio.TimeoutError: If no result received within timeout """ # Allow event subscriptions to be processed await asyncio.sleep(0.01) - + # Wait for an event await asyncio.wait_for(self._event.wait(), timeout=timeout) - + def get_full_transcript(self) -> str: """Get full transcription text from all transcript events. 
- + Returns: Combined text from all transcripts """ @@ -90,7 +90,7 @@ def assets_dir(): def mia_audio_16khz(): """Load mia.mp3 and convert to 16kHz PCM data.""" audio_file_path = os.path.join(get_assets_dir(), "mia.mp3") - + # Load audio file using PyAV container = av.open(audio_file_path) audio_stream = container.streams.audio[0] @@ -100,11 +100,7 @@ def mia_audio_16khz(): # Create resampler if needed resampler = None if original_sample_rate != target_rate: - resampler = av.AudioResampler( - format='s16', - layout='mono', - rate=target_rate - ) + resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate) # Read all audio frames samples = [] @@ -128,11 +124,7 @@ def mia_audio_16khz(): container.close() # Create PCM data - pcm = PcmData( - samples=samples, - sample_rate=target_rate, - format="s16" - ) + pcm = PcmData(samples=samples, sample_rate=target_rate, format="s16") return pcm @@ -141,7 +133,7 @@ def mia_audio_16khz(): def mia_audio_48khz(): """Load mia.mp3 and convert to 48kHz PCM data.""" audio_file_path = os.path.join(get_assets_dir(), "mia.mp3") - + # Load audio file using PyAV container = av.open(audio_file_path) audio_stream = container.streams.audio[0] @@ -151,11 +143,7 @@ def mia_audio_48khz(): # Create resampler if needed resampler = None if original_sample_rate != target_rate: - resampler = av.AudioResampler( - format='s16', - layout='mono', - rate=target_rate - ) + resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate) # Read all audio frames samples = [] @@ -179,11 +167,7 @@ def mia_audio_48khz(): container.close() # Create PCM data - pcm = PcmData( - samples=samples, - sample_rate=target_rate, - format="s16" - ) + pcm = PcmData(samples=samples, sample_rate=target_rate, format="s16") return pcm @@ -192,10 +176,10 @@ def mia_audio_48khz(): def golf_swing_image(): """Load golf_swing.png image and return as bytes.""" image_file_path = os.path.join(get_assets_dir(), "golf_swing.png") - + with open(image_file_path, "rb") as f: image_bytes = f.read() - + return image_bytes @@ -203,7 +187,7 @@ def golf_swing_image(): async def bunny_video_track(): """Create RealVideoTrack from video file.""" from aiortc import VideoStreamTrack - + video_file_path = os.path.join(get_assets_dir(), "bunny_3s.mp4") class RealVideoTrack(VideoStreamTrack): @@ -223,12 +207,12 @@ async def recv(self): for frame in self.container.decode(self.video_stream): if frame is None: raise asyncio.CancelledError("End of video stream") - + self.frame_count += 1 frame = frame.to_rgb() await asyncio.sleep(self.frame_duration) return frame - + raise asyncio.CancelledError("End of video stream") except asyncio.CancelledError: @@ -245,4 +229,3 @@ async def recv(self): yield track finally: track.container.close() - diff --git a/plugins/aws/README.md b/plugins/aws/README.md index ae1617ea..feb3b2b5 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -20,7 +20,7 @@ agent = Agent( agent_user=User(name="Friendly AI"), instructions="Be nice to the user", llm=aws.LLM(model="qwen.qwen3-32b-v1:0"), - tts=cartesia.TTS(), + tts=aws.TTS(), # using AWS Polly stt=deepgram.STT(), turn_detection=smart_turn.TurnDetection(buffer_duration=2.0, confidence_threshold=0.5), ) @@ -28,8 +28,6 @@ agent = Agent( The full example is available in example/aws_qwen_example.py -### Realtime Text/Image Usage - Nova sonic audio realtime STS is also supported: ```python @@ -43,18 +41,6 @@ agent = Agent( ### Polly TTS Usage -```python -from vision_agents.plugins import aws -from 
vision_agents.core.tts.manual_test import manual_tts_to_wav -import asyncio - -async def main(): - # For PCM, AWS Polly supports 8000 or 16000 Hz - tts = aws.TTS(voice_id="Joanna", sample_rate=16000) - await manual_tts_to_wav(tts, sample_rate=16000, channels=1) - -asyncio.run(main()) -``` ## Running the examples diff --git a/plugins/aws/example/aws_polly_tts_example.py b/plugins/aws/example/aws_polly_tts_example.py index c93f1880..7b962b37 100644 --- a/plugins/aws/example/aws_polly_tts_example.py +++ b/plugins/aws/example/aws_polly_tts_example.py @@ -9,7 +9,7 @@ async def main(): load_dotenv() tts = TTS(voice_id=os.environ.get("AWS_POLLY_VOICE", "Joanna")) - await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + await manual_tts_to_wav(tts, sample_rate=48000, channels=2) if __name__ == "__main__": diff --git a/plugins/aws/tests/test_tts.py b/plugins/aws/tests/test_tts.py index a881d02e..78be7b53 100644 --- a/plugins/aws/tests/test_tts.py +++ b/plugins/aws/tests/test_tts.py @@ -45,7 +45,7 @@ async def test_aws_polly_tts_speech(self, tts: aws_plugin.TTS): @pytest.mark.integration async def test_aws_polly_tts_manual_wav(self, tts: aws_plugin.TTS): - await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + await manual_tts_to_wav(tts, sample_rate=48000, channels=2) @pytest.mark.integration async def test_aws_polly_tts_non_blocking(self, tts: aws_plugin.TTS): diff --git a/plugins/cartesia/tests/test_tts.py b/plugins/cartesia/tests/test_tts.py index b90f5bf3..0046c335 100644 --- a/plugins/cartesia/tests/test_tts.py +++ b/plugins/cartesia/tests/test_tts.py @@ -32,7 +32,7 @@ async def test_cartesia_with_real_api(self, tts): @pytest.mark.integration async def test_cartesia_tts_convert_text_to_audio_manual_test(self, tts): - await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + await manual_tts_to_wav(tts, sample_rate=48000, channels=2) @pytest.mark.integration async def test_cartesia_tts_non_blocking(self, tts: cartesia.TTS): diff --git a/plugins/elevenlabs/tests/test_tts.py b/plugins/elevenlabs/tests/test_tts.py index 829e6be9..d92607a0 100644 --- a/plugins/elevenlabs/tests/test_tts.py +++ b/plugins/elevenlabs/tests/test_tts.py @@ -30,7 +30,7 @@ async def test_elevenlabs_with_real_api(self, tts): @pytest.mark.integration async def test_elevenlabs_tts_convert_text_to_audio_manual_test(self, tts): - path = await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + path = await manual_tts_to_wav(tts, sample_rate=48000, channels=2) print("ElevenLabs TTS audio written to:", path) @pytest.mark.integration diff --git a/plugins/fish/tests/test_fish_tts.py b/plugins/fish/tests/test_fish_tts.py index a1450201..82985e56 100644 --- a/plugins/fish/tests/test_fish_tts.py +++ b/plugins/fish/tests/test_fish_tts.py @@ -17,7 +17,7 @@ async def tts(self) -> fish.TTS: @pytest.mark.integration async def test_fish_tts_convert_text_to_audio_manual_test(self, tts: fish.TTS): - await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + await manual_tts_to_wav(tts, sample_rate=48000, channels=2) @pytest.mark.integration async def test_fish_tts_convert_text_to_audio(self, tts: fish.TTS): diff --git a/plugins/kokoro/tests/test_tts.py b/plugins/kokoro/tests/test_tts.py index c19d6b03..551d5f67 100644 --- a/plugins/kokoro/tests/test_tts.py +++ b/plugins/kokoro/tests/test_tts.py @@ -15,4 +15,4 @@ def tts(self): # returns kokoro TTS if available @pytest.mark.integration async def test_kokoro_tts_convert_text_to_audio_manual_test(self, tts): - await manual_tts_to_wav(tts, sample_rate=24000, channels=1) + 
await manual_tts_to_wav(tts, sample_rate=48000, channels=2) diff --git a/plugins/openai/tests/test_tts_openai.py b/plugins/openai/tests/test_tts_openai.py index d00ae4ba..8624b941 100644 --- a/plugins/openai/tests/test_tts_openai.py +++ b/plugins/openai/tests/test_tts_openai.py @@ -28,7 +28,7 @@ async def test_openai_tts_speech(self, tts: openai_plugin.TTS): @pytest.mark.integration async def test_openai_tts_manual_wav(self, tts: openai_plugin.TTS): - await manual_tts_to_wav(tts, sample_rate=16000, channels=1) + await manual_tts_to_wav(tts, sample_rate=48000, channels=2) @pytest.mark.integration async def test_openai_tts_non_blocking(self, tts: openai_plugin.TTS): diff --git a/tests/test_pcm_data.py b/tests/test_pcm_data.py new file mode 100644 index 00000000..726c2835 --- /dev/null +++ b/tests/test_pcm_data.py @@ -0,0 +1,267 @@ +import numpy as np + +from vision_agents.core.edge.types import PcmData + + +def _i16_list_from_bytes(b: bytes): + return list(np.frombuffer(b, dtype=np.int16)) + + +def test_to_bytes_interleaves_from_channel_major(): + # Create (channels, samples) data: L=[1,2,3,4], R=[-1,-2,-3,-4] + samples = np.array( + [ + [1, 2, 3, 4], + [-1, -2, -3, -4], + ], + dtype=np.int16, + ) + pcm = PcmData(samples=samples, sample_rate=16000, format="s16", channels=2) + out = _i16_list_from_bytes(pcm.to_bytes()) + assert out == [1, -1, 2, -2, 3, -3, 4, -4] + + +def test_to_bytes_interleaves_from_time_major(): + # Create (samples, channels) data: time-major + time_major = np.array( + [ + [1, -1], + [2, -2], + [3, -3], + [4, -4], + ], + dtype=np.int16, + ) + pcm = PcmData(samples=time_major, sample_rate=16000, format="s16", channels=2) + out = _i16_list_from_bytes(pcm.to_bytes()) + assert out == [1, -1, 2, -2, 3, -3, 4, -4] + + +def test_resample_upmix_produces_channel_major_and_interleaved_bytes(): + # Mono ramp 1..10 + mono = np.arange(1, 11, dtype=np.int16) + pcm_mono = PcmData(samples=mono, sample_rate=16000, format="s16", channels=1) + + # Upmix to stereo (same sample rate) + pcm_stereo = pcm_mono.resample(16000, target_channels=2) + assert pcm_stereo.channels == 2 + assert hasattr(pcm_stereo, "samples") + assert isinstance(pcm_stereo.samples, np.ndarray) + assert pcm_stereo.samples.ndim == 2 + # Expect (channels, samples) shape + assert pcm_stereo.samples.shape[0] == 2 + # Sample count may be >= input due to resampler buffering; check prefix + assert pcm_stereo.samples.shape[1] >= mono.shape[0] + # Both channels should be identical after upmix + assert np.array_equal(pcm_stereo.samples[0], pcm_stereo.samples[1]) + + # Bytes should be interleaved L0,R0,L1,R1,... 
+ out_bytes = pcm_stereo.to_bytes() + # Verify interleaving pattern: L[i] == R[i] for a prefix + out_i16 = _i16_list_from_bytes(out_bytes) + # take first 2 * N pairs (N from input) + pairs = min(len(mono), len(out_i16) // 2) + left = out_i16[0 : 2 * pairs : 2] + right = out_i16[1 : 2 * pairs : 2] + assert left == right + + +def test_resample_rate_and_stereo_size_scaling(): + # 200 mono samples @16kHz -> expect ~3x samples at 48kHz and x2 for stereo + mono = np.arange(200, dtype=np.int16) + pcm_mono = PcmData(samples=mono, sample_rate=16000, format="s16", channels=1) + + pcm_48k_stereo = pcm_mono.resample(48000, target_channels=2) + out = pcm_48k_stereo.to_bytes() + + # 16-bit stereo -> 4 bytes per sample frame + # 20ms at 48k is 960 frames = 3840 bytes; our total depends on input size + # Sanity: output length should be >= input_bytes * 6 - small tolerance + input_bytes = mono.nbytes + assert len(out) >= input_bytes * 5 # conservative lower bound + + +# ===== Bug reproduction tests ===== + + +def test_bug_mono_to_stereo_duration_preserved(): + """ + BUG REPRODUCTION: Converting mono to stereo should preserve duration. + If duration changes, playback will be slowed down or sped up. + """ + # Create 1 second of mono audio at 16kHz + sample_rate = 16000 + duration_sec = 1.0 + num_samples = int(sample_rate * duration_sec) + + # Generate a simple sine wave + t = np.linspace(0, duration_sec, num_samples, dtype=np.float32) + audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + + pcm_mono = PcmData(samples=audio, sample_rate=sample_rate, format="s16", channels=1) + + # Check initial duration + mono_duration = pcm_mono.duration + print(f"\nMono duration: {mono_duration}s (expected ~1.0s)") + assert abs(mono_duration - duration_sec) < 0.01, ( + f"Mono duration should be ~{duration_sec}s, got {mono_duration}s" + ) + + # Convert to stereo (no resampling, just channel conversion) + pcm_stereo = pcm_mono.resample(sample_rate, target_channels=2) + + # Duration should be EXACTLY the same + stereo_duration = pcm_stereo.duration + print(f"Stereo duration: {stereo_duration}s (expected ~1.0s)") + print(f"Stereo shape: {pcm_stereo.samples.shape} (expected (2, {num_samples}))") + + assert abs(stereo_duration - duration_sec) < 0.01, ( + f"Stereo duration should be ~{duration_sec}s, got {stereo_duration}s (BUG: playback will be wrong!)" + ) + + # Verify shape is correct (channels, samples) + assert pcm_stereo.samples.shape[0] == 2, ( + f"First dimension should be channels (2), got shape {pcm_stereo.samples.shape}" + ) + assert pcm_stereo.samples.shape[1] >= num_samples - 10, ( + f"Second dimension should be ~samples ({num_samples}), got shape {pcm_stereo.samples.shape}" + ) + + +def test_bug_resample_16khz_to_48khz_quality(): + """ + BUG REPRODUCTION: Resampling 16kHz to 48kHz should produce correct sample count. + If sample count is wrong, quality will be bad. 
+ """ + # Create 1 second of mono audio at 16kHz + sample_rate_in = 16000 + sample_rate_out = 48000 + duration_sec = 1.0 + num_samples_in = int(sample_rate_in * duration_sec) + + # Generate a simple sine wave + t = np.linspace(0, duration_sec, num_samples_in, dtype=np.float32) + audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + + pcm_16k = PcmData( + samples=audio, sample_rate=sample_rate_in, format="s16", channels=1 + ) + + # Resample to 48kHz + pcm_48k = pcm_16k.resample(sample_rate_out, target_channels=1) + + # Check that sample count increased by 3x (48k/16k = 3) + expected_samples = num_samples_in * 3 + actual_samples = ( + len(pcm_48k.samples) if pcm_48k.samples.ndim == 1 else pcm_48k.samples.shape[-1] + ) + + print(f"\n16kHz samples: {num_samples_in}") + print(f"48kHz samples: {actual_samples} (expected ~{expected_samples})") + print(f"48kHz shape: {pcm_48k.samples.shape}") + print(f"48kHz duration: {pcm_48k.duration}s (expected ~1.0s)") + + # Allow some tolerance for resampler edge effects + assert abs(actual_samples - expected_samples) < 100, ( + f"Expected ~{expected_samples} samples at 48kHz, got {actual_samples} (BUG: quality will be bad!)" + ) + + # Duration should remain the same + assert abs(pcm_48k.duration - duration_sec) < 0.01, ( + f"Duration should remain ~{duration_sec}s, got {pcm_48k.duration}s" + ) + + +def test_bug_resample_16khz_to_48khz_stereo_combined(): + """ + BUG REPRODUCTION: The worst case - 16kHz mono to 48kHz stereo. + This combines both bugs: resampling quality AND duration preservation. + """ + # Create 1 second of mono audio at 16kHz + sample_rate_in = 16000 + sample_rate_out = 48000 + duration_sec = 1.0 + num_samples_in = int(sample_rate_in * duration_sec) + + # Generate a simple sine wave + t = np.linspace(0, duration_sec, num_samples_in, dtype=np.float32) + audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + + pcm_16k = PcmData( + samples=audio, sample_rate=sample_rate_in, format="s16", channels=1 + ) + + # Resample to 48kHz stereo (the problematic case!) + pcm_48k_stereo = pcm_16k.resample(sample_rate_out, target_channels=2) + + print(f"\n16kHz mono shape: {pcm_16k.samples.shape}") + print(f"48kHz stereo shape: {pcm_48k_stereo.samples.shape}") + print(f"48kHz stereo duration: {pcm_48k_stereo.duration}s (expected ~1.0s)") + + # Check shape is correct (channels, samples) + assert pcm_48k_stereo.samples.ndim == 2, "Should be 2D array" + assert pcm_48k_stereo.samples.shape[0] == 2, ( + f"First dimension should be channels (2), got shape {pcm_48k_stereo.samples.shape}" + ) + + # Check that sample count increased by 3x (48k/16k = 3) + expected_samples = num_samples_in * 3 + actual_samples = pcm_48k_stereo.samples.shape[1] + + print(f"Expected ~{expected_samples} samples, got {actual_samples}") + + # Allow some tolerance for resampler edge effects + assert abs(actual_samples - expected_samples) < 100, ( + f"Expected ~{expected_samples} samples at 48kHz, got {actual_samples}" + ) + + # Duration should remain the same (THIS IS THE CRITICAL BUG) + assert abs(pcm_48k_stereo.duration - duration_sec) < 0.01, ( + f"Duration should remain ~{duration_sec}s, got {pcm_48k_stereo.duration}s (BUG: causes slow playback!)" + ) + + +def test_bug_duration_with_different_array_shapes(): + """ + BUG REPRODUCTION: Duration calculation should work with any array shape. + The bug is that shape[-1] is used, which gives wrong results for (samples, channels) arrays. 
+ """ + sample_rate = 16000 + num_samples = 16000 # 1 second + expected_duration = 1.0 + + # Test 1: 1D array (mono) - should work + samples_1d = np.zeros(num_samples, dtype=np.int16) + pcm_1d = PcmData( + samples=samples_1d, sample_rate=sample_rate, format="s16", channels=1 + ) + print(f"\n1D mono: shape={pcm_1d.samples.shape}, duration={pcm_1d.duration}s") + assert abs(pcm_1d.duration - expected_duration) < 0.01 + + # Test 2: 2D array (channels, samples) - CORRECT format, should work + samples_2d_correct = np.zeros((2, num_samples), dtype=np.int16) + pcm_2d_correct = PcmData( + samples=samples_2d_correct, sample_rate=sample_rate, format="s16", channels=2 + ) + print( + f"2D (channels, samples): shape={pcm_2d_correct.samples.shape}, duration={pcm_2d_correct.duration}s" + ) + assert abs(pcm_2d_correct.duration - expected_duration) < 0.01 + + # Test 3: 2D array (samples, channels) - WRONG format but might happen from PyAV + # This is where the bug manifests! + samples_2d_wrong = np.zeros((num_samples, 2), dtype=np.int16) + pcm_2d_wrong = PcmData( + samples=samples_2d_wrong, sample_rate=sample_rate, format="s16", channels=2 + ) + wrong_duration = pcm_2d_wrong.duration + print( + f"2D (samples, channels): shape={pcm_2d_wrong.samples.shape}, duration={wrong_duration}s" + ) + + # With current buggy code using shape[-1], this will give duration = 2/16000 = 0.000125s + # But we want it to be 1.0s + # This assertion will FAIL with the bug + assert abs(wrong_duration - expected_duration) < 0.01, ( + f"Duration with (samples, channels) shape is WRONG: {wrong_duration}s (expected {expected_duration}s)" + ) diff --git a/tests/test_resample_quality.py b/tests/test_resample_quality.py new file mode 100644 index 00000000..e19c7c50 --- /dev/null +++ b/tests/test_resample_quality.py @@ -0,0 +1,146 @@ +"""Test to investigate resampling quality issues.""" + +import numpy as np +from vision_agents.core.edge.types import PcmData +import av + + +def test_compare_resampling_methods(): + """Compare PyAV resampling with scipy for quality.""" + # Create 1 second of clean sine wave at 16kHz + sample_rate_in = 16000 + sample_rate_out = 48000 + duration = 1.0 + freq = 440 # A4 note + + num_samples_in = int(sample_rate_in * duration) + t = np.linspace(0, duration, num_samples_in, dtype=np.float64) + + # Generate clean sine wave + sine_wave = np.sin(2 * np.pi * freq * t) + audio_int16 = (sine_wave * 32767).astype(np.int16) + + # Method 1: PyAV (current implementation) + pcm_16k = PcmData( + samples=audio_int16, sample_rate=sample_rate_in, format="s16", channels=1 + ) + pcm_48k_pyav = pcm_16k.resample(sample_rate_out, target_channels=1) + + print("\n=== PyAV Resampler ===") + print(f"Input: {len(audio_int16)} samples @ {sample_rate_in}Hz") + print(f"Output: {len(pcm_48k_pyav.samples)} samples @ {sample_rate_out}Hz") + print(f"Output dtype: {pcm_48k_pyav.samples.dtype}") + print(f"Output shape: {pcm_48k_pyav.samples.shape}") + + # Check for clipping or artifacts + pyav_samples = ( + pcm_48k_pyav.samples.flatten() + if pcm_48k_pyav.samples.ndim > 1 + else pcm_48k_pyav.samples + ) + print(f"Output min: {pyav_samples.min()}, max: {pyav_samples.max()}") + + # Check for discontinuities (potential clicks) + diffs = np.abs(np.diff(pyav_samples.astype(np.float32))) + max_jump = np.max(diffs) + mean_jump = np.mean(diffs) + print(f"Max sample-to-sample jump: {max_jump:.1f}") + print(f"Mean sample-to-sample jump: {mean_jump:.1f}") + + # Large jumps indicate clicks + large_jumps = np.sum(diffs > 10000) + print(f"Number of large jumps 
(>10000): {large_jumps}") + + # Method 2: Try scipy for comparison + try: + from scipy import signal + + # Resample using scipy's high-quality resampler + num_samples_out = int(len(audio_int16) * sample_rate_out / sample_rate_in) + audio_float = audio_int16.astype(np.float32) / 32768.0 + resampled_scipy = signal.resample(audio_float, num_samples_out) + resampled_scipy_int16 = (np.clip(resampled_scipy, -1.0, 1.0) * 32767).astype( + np.int16 + ) + + print("\n=== SciPy Resampler ===") + print(f"Output: {len(resampled_scipy_int16)} samples @ {sample_rate_out}Hz") + print( + f"Output min: {resampled_scipy_int16.min()}, max: {resampled_scipy_int16.max()}" + ) + + diffs_scipy = np.abs(np.diff(resampled_scipy_int16.astype(np.float32))) + max_jump_scipy = np.max(diffs_scipy) + mean_jump_scipy = np.mean(diffs_scipy) + print(f"Max sample-to-sample jump: {max_jump_scipy:.1f}") + print(f"Mean sample-to-sample jump: {mean_jump_scipy:.1f}") + + large_jumps_scipy = np.sum(diffs_scipy > 10000) + print(f"Number of large jumps (>10000): {large_jumps_scipy}") + + # Save both for manual inspection + import wave + import io + + def save_wav(samples, sr, filename): + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sr) + wf.writeframes(samples.tobytes()) + with open(filename, "wb") as f: + f.write(buf.getvalue()) + + save_wav(audio_int16, sample_rate_in, "/tmp/original_16k.wav") + save_wav(pyav_samples.astype(np.int16), sample_rate_out, "/tmp/pyav_48k.wav") + save_wav(resampled_scipy_int16, sample_rate_out, "/tmp/scipy_48k.wav") + + print("\nWAV files saved to /tmp/ for comparison") + + except ImportError: + print("\nSciPy not available for comparison") + + +def test_pyav_resampler_settings(): + """Check if PyAV resampler has quality settings we're missing.""" + sample_rate_in = 16000 + sample_rate_out = 48000 + num_samples = 16000 + + # Create test signal + t = np.linspace(0, 1.0, num_samples, dtype=np.float64) + audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + + # Create frame + frame = av.AudioFrame.from_ndarray( + audio.reshape(1, -1), format="s16p", layout="mono" + ) + frame.sample_rate = sample_rate_in + + # Try different resampler configurations + print("\n=== Testing PyAV Resampler Options ===") + + # Default resampler + resampler_default = av.AudioResampler( + format="s16", layout="mono", rate=sample_rate_out + ) + + print("Default resampler created") + print(f"Resampler: {resampler_default}") + + # Check if there are any quality options available + # Note: PyAV/FFmpeg's libswresample has quality options but might not be exposed + + frames = resampler_default.resample(frame) + if frames: + result = frames[0].to_ndarray().flatten() + print(f"Default output: {len(result)} samples") + + diffs = np.abs(np.diff(result.astype(np.float32))) + print(f"Max jump: {np.max(diffs):.1f}, Mean jump: {np.mean(diffs):.1f}") + + +if __name__ == "__main__": + test_compare_resampling_methods() + test_pyav_resampler_settings() From bba5ea7af7502b7c6c84c0efece4428efd3708f2 Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 21:58:00 +0200 Subject: [PATCH 10/15] working resampling mechanism --- DEVELOPMENT.md | 25 ++++++ agents-core/vision_agents/core/edge/types.py | 20 +++-- agents-core/vision_agents/core/tts/tts.py | 90 ++------------------ 3 files changed, 44 insertions(+), 91 deletions(-) diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index d3b251b6..50540a3f 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -109,6 +109,31 
@@ To see how the agent work open up agents.py
 * The LLM uses the VideoForwarder to write the video to a websocket or webrtc connection
 * The STS writes the reply on agent.llm.audio_track and the RealtimeTranscriptEvent / RealtimePartialTranscriptEvent
+## Audio management
+
+Some important things about audio inside the library:
+
+1. WebRTC uses Opus at 48kHz stereo, but inside the library audio is always handled as PCM
+2. Plugins / AI models work with different PCM formats, usually 16kHz mono
+3. PCM data is always passed around using the `PcmData` object, which carries the sample rate, channels and format
+4. Text-to-speech plugins automatically return PCM in the format needed by WebRTC. This is exposed via the `set_output_format` method
+5. Audio resampling can be done using the `PcmData.resample` method
+6. When resampling audio in chunks, it is important to re-use the same `av.AudioResampler` instance (see `PcmData.resample` and `core.tts.TTS`)
+7. Converting from stereo to mono and vice versa can also be done with the `PcmData.resample` method
+
+Some ground rules:
+
+1. Do not write your own code to resample / adjust audio unless the case is not already covered by `PcmData`
+2. Do not pass PCM around as plain bytes or write code that assumes a specific sample rate or format; use `PcmData` instead
+
+### Testing audio manually
+
+Sometimes you need to test audio manually; here are some tips:
+
+1. Do not use earplugs when testing PCM playback ;)
+2. You can use the `PcmData.to_wav_bytes` method to convert PCM into WAV bytes (see `manual_tts_to_wav` for an example)
+3. If you have `ffplay` installed, you can play back PCM directly to check that the audio is correct
+
 ## Dev / Contributor Guidelines
 
 ### Light wrapping
diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py
index 24b2fb6d..e1170136 100644
--- a/agents-core/vision_agents/core/edge/types.py
+++ b/agents-core/vision_agents/core/edge/types.py
@@ -287,7 +287,10 @@ def from_data(
         raise TypeError(f"Unsupported data type for PcmData: {type(data)}")
 
     def resample(
-        self, target_sample_rate: int, target_channels: Optional[int] = None
+        self,
+        target_sample_rate: int,
+        target_channels: Optional[int] = None,
+        resampler: Optional[Any] = None,
     ) -> "PcmData":
         """
         Resample PcmData to a different sample rate and/or channels using AV library.
@@ -295,6 +298,9 @@ def resample(
         Args:
             target_sample_rate: Target sample rate in Hz
             target_channels: Target number of channels (defaults to current)
+            resampler: Optional persistent AudioResampler instance to use. If None,
+                creates a new resampler (for one-off use). Pass a persistent
+                resampler to avoid discontinuities when resampling streaming chunks.
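+
+        Example:
+            A minimal sketch of resampling a stream in chunks with one shared
+            resampler (``pcm_chunks`` is a placeholder for any iterable of
+            PcmData chunks)::
+
+                resampler = av.AudioResampler(
+                    format="s16", layout="stereo", rate=48000
+                )
+                for chunk in pcm_chunks:
+                    out = chunk.resample(
+                        48000, target_channels=2, resampler=resampler
+                    )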
Returns: New PcmData object with resampled audio @@ -336,11 +342,13 @@ def resample( frame = av.AudioFrame.from_ndarray(cmaj, format="s16p", layout=in_layout) frame.sample_rate = self.sample_rate - # Create resampler – output packed s16 - out_layout = "mono" if target_channels == 1 else "stereo" - resampler = av.AudioResampler( - format="s16", layout=out_layout, rate=target_sample_rate - ) + # Use provided resampler or create a new one + if resampler is None: + # Create new resampler for one-off use + out_layout = "mono" if target_channels == 1 else "stereo" + resampler = av.AudioResampler( + format="s16", layout=out_layout, rate=target_sample_rate + ) # Resample the frame resampled_frames = resampler.resample(frame) diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index d5677e46..2f0c8f2e 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -182,91 +182,11 @@ def _emit_chunk( user: Optional[Dict[str, Any]], ) -> tuple[int, float]: """Resample, serialize, emit TTSAudioEvent; return (bytes_len, duration_ms).""" - - if ( - pcm.sample_rate == self._desired_sample_rate - and pcm.channels == self._desired_channels - ): - # No resampling needed - pcm_out = pcm - else: - resampler = self._get_resampler(pcm.sample_rate, pcm.channels) - - # Prepare input frame in planar format - samples = pcm.samples - if isinstance(samples, np.ndarray): - if samples.ndim == 1: - if pcm.channels > 1: - cmaj = np.tile(samples, (pcm.channels, 1)) - else: - cmaj = samples.reshape(1, -1) - elif samples.ndim == 2: - ch = pcm.channels if pcm.channels else 1 - if samples.shape[0] == ch: - cmaj = samples - elif samples.shape[1] == ch: - cmaj = samples.T - else: - if samples.shape[1] > samples.shape[0]: - cmaj = samples - else: - cmaj = samples.T - cmaj = np.ascontiguousarray(cmaj) - else: - # Shouldn't happen, but handle it - cmaj = ( - samples.reshape(1, -1) - if isinstance(samples, np.ndarray) - else samples - ) - - in_layout = "mono" if pcm.channels == 1 else "stereo" - frame = av.AudioFrame.from_ndarray(cmaj, format="s16p", layout=in_layout) - frame.sample_rate = pcm.sample_rate - - # Resample using persistent resampler - resampled_frames = resampler.resample(frame) - - if resampled_frames: - resampled_frame = resampled_frames[0] - raw_array = resampled_frame.to_ndarray() - num_frames = resampled_frame.samples - - # Handle PyAV's packed format quirk - ch = self._desired_channels - if raw_array.ndim == 2 and raw_array.shape[0] == 1 and ch > 1: - flat = raw_array.reshape(-1) - if len(flat) == num_frames * ch: - resampled_samples = flat.reshape(-1, ch).T - else: - resampled_samples = flat.reshape(ch, -1) - elif raw_array.ndim == 2: - if raw_array.shape[1] == ch: - resampled_samples = raw_array.T - elif raw_array.shape[0] == ch: - resampled_samples = raw_array - else: - resampled_samples = raw_array.T - elif raw_array.ndim == 1: - if ch == 1: - resampled_samples = raw_array - else: - resampled_samples = np.tile(raw_array, (ch, 1)) - else: - resampled_samples = raw_array.reshape(ch, -1) - - if resampled_samples.dtype != np.int16: - resampled_samples = resampled_samples.astype(np.int16) - - pcm_out = PcmData( - samples=resampled_samples, - sample_rate=self._desired_sample_rate, - format="s16", - channels=self._desired_channels, - ) - else: - # Resampling failed, use original - pcm_out = pcm + # Resample using persistent resampler to avoid discontinuities between chunks + resampler = self._get_resampler(pcm.sample_rate, 
pcm.channels)
+        pcm_out = pcm.resample(
+            self._desired_sample_rate, self._desired_channels, resampler=resampler
+        )
 
         payload = pcm_out.to_bytes()
         # Metrics: counters per chunk

From 353041ba1ebdea3f321d7bb4d5a97570787dcf56 Mon Sep 17 00:00:00 2001
From: Tommaso Barbugli
Date: Fri, 24 Oct 2025 22:32:48 +0200
Subject: [PATCH 11/15] working resampling mechanism

---
 DEVELOPMENT.md                                | 42 ++++++++++++
 agents-core/vision_agents/core/edge/types.py  | 66 +++++++++++++++++++
 .../vision_agents/core/tts/manual_test.py     | 40 ++++-------
 agents-core/vision_agents/core/tts/tts.py     | 51 +++++---------
 4 files changed, 135 insertions(+), 64 deletions(-)

diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 50540a3f..c953a070 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -126,6 +126,48 @@ Some ground rules:
 1. Do not write your own code to resample / adjust audio unless the case is not already covered by `PcmData`
 2. Do not pass PCM around as plain bytes or write code that assumes a specific sample rate or format; use `PcmData` instead
 
+## Example
+
+```python
+import asyncio
+from vision_agents.core.edge.types import PcmData
+from openai import AsyncOpenAI
+
+async def example():
+    client = AsyncOpenAI(api_key="sk-42")
+
+    resp = await client.audio.speech.create(
+        model="gpt-4o-mini-tts",
+        voice="alloy",
+        input="pcm is cool, give me some of that please",
+        response_format="pcm",
+    )
+
+    # load the response into PcmData; note that you need to specify sample_rate, channels and format
+    pcm_data = PcmData.from_bytes(
+        resp.content, sample_rate=24_000, channels=1, format="s16"
+    )
+
+    # check if pcm_data is stereo (it's not in this case, of course)
+    print(pcm_data.stereo)
+
+    # write the pcm to a file
+    with open("test.wav", "wb") as f:
+        f.write(pcm_data.to_wav_bytes())
+
+    # resample the pcm to 48kHz stereo
+    resampled_pcm = pcm_data.resample(48_000, 2)
+
+    # play back the pcm using ffplay
+    from vision_agents.core.edge.types import play_pcm_with_ffplay
+
+    await play_pcm_with_ffplay(resampled_pcm)
+
+if __name__ == "__main__":
+    asyncio.run(example())
+```
+
+
 ### Testing audio manually
 
 Sometimes you need to test audio manually; here are some tips:
diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py
index e1170136..95ba2d6e 100644
--- a/agents-core/vision_agents/core/edge/types.py
+++ b/agents-core/vision_agents/core/edge/types.py
@@ -15,6 +15,11 @@
 from numpy._typing import NDArray
 from pyee.asyncio import AsyncIOEventEmitter
 import av
+import asyncio
+import os
+import shutil
+import tempfile
+import time
 
 logger = logging.getLogger(__name__)
 
@@ -76,6 +81,10 @@ class PcmData(NamedTuple):
     time_base: Optional[float] = None  # Time base for converting timestamps to seconds
     channels: int = 1  # Number of channels (1=mono, 2=stereo)
 
+    @property
+    def stereo(self) -> bool:
+        return self.channels == 2
+
     @property
     def duration(self) -> float:
         """
@@ -636,3 +645,60 @@ def _gen():
     raise TypeError(
         f"Unsupported response type for PcmData.from_response: {type(response)}"
     )
+
+
+async def play_pcm_with_ffplay(
+    pcm: PcmData,
+    outfile_path: Optional[str] = None,
+    timeout_s: float = 30.0,
+) -> str:
+    """Write PcmData to a WAV file and optionally play it with ffplay.
+
+    This is a utility function for testing and debugging audio output.
+
+    Args:
+        pcm: PcmData object to play
+        outfile_path: Optional path for the WAV file. If None, creates a temp file. 
+ timeout_s: Timeout in seconds for ffplay playback (default: 30.0) + + Returns: + Path to the written WAV file + + Example: + pcm = PcmData.from_bytes(audio_bytes, sample_rate=48000, channels=2) + wav_path = await play_pcm_with_ffplay(pcm) + """ + + # Generate output path if not provided + if outfile_path is None: + tmpdir = tempfile.gettempdir() + timestamp = int(time.time()) + outfile_path = os.path.join(tmpdir, f"pcm_playback_{timestamp}.wav") + + # Write WAV file + with open(outfile_path, "wb") as f: + f.write(pcm.to_wav_bytes()) + + logger.info(f"Wrote WAV file: {outfile_path}") + + # Optional playback with ffplay + if shutil.which("ffplay"): + logger.info("Playing audio with ffplay...") + proc = await asyncio.create_subprocess_exec( + "ffplay", + "-autoexit", + "-nodisp", + "-hide_banner", + "-loglevel", + "error", + outfile_path, + ) + try: + await asyncio.wait_for(proc.wait(), timeout=timeout_s) + except asyncio.TimeoutError: + logger.warning(f"ffplay timed out after {timeout_s}s, killing process") + proc.kill() + else: + logger.warning("ffplay not found in PATH, skipping playback") + + return outfile_path diff --git a/agents-core/vision_agents/core/tts/manual_test.py b/agents-core/vision_agents/core/tts/manual_test.py index 4d2473d4..836ab05b 100644 --- a/agents-core/vision_agents/core/tts/manual_test.py +++ b/agents-core/vision_agents/core/tts/manual_test.py @@ -1,13 +1,12 @@ -import asyncio import os -import shutil import tempfile import time + from typing import Optional from vision_agents.core.tts import TTS from vision_agents.core.tts.testing import TTSSession -from vision_agents.core.edge.types import PcmData +from vision_agents.core.edge.types import PcmData, play_pcm_with_ffplay async def manual_tts_to_wav( @@ -18,7 +17,6 @@ async def manual_tts_to_wav( text: str = "This is a manual TTS playback test.", outfile_path: Optional[str] = None, timeout_s: float = 20.0, - play_env: str = "FFPLAY", ) -> str: """Generate TTS audio to a WAV file and optionally play with ffplay. 
@@ -48,35 +46,19 @@ async def manual_tts_to_wav( if result.errors: raise RuntimeError(f"TTS errors: {result.errors}") - # Write WAV file (16kHz mono, s16) - if outfile_path is None: - tmpdir = tempfile.gettempdir() - timestamp = int(time.time()) - outfile_path = os.path.join( - tmpdir, f"tts_manual_test_{tts.__class__.__name__}_{timestamp}.wav" - ) - + # Convert captured audio to PcmData pcm_bytes = b"".join(result.speeches) pcm = PcmData.from_bytes( pcm_bytes, sample_rate=sample_rate, channels=channels, format="s16" ) - with open(outfile_path, "wb") as f: - f.write(pcm.to_wav_bytes()) - # Optional playback - if os.environ.get(play_env) == "1" and shutil.which("ffplay"): - proc = await asyncio.create_subprocess_exec( - "ffplay", - "-autoexit", - "-nodisp", - "-hide_banner", - "-loglevel", - "error", - outfile_path, + # Generate a descriptive filename if not provided + if outfile_path is None: + tmpdir = tempfile.gettempdir() + timestamp = int(time.time()) + outfile_path = os.path.join( + tmpdir, f"tts_manual_test_{tts.__class__.__name__}_{timestamp}.wav" ) - try: - await asyncio.wait_for(proc.wait(), timeout=30.0) - except asyncio.TimeoutError: - proc.kill() - return outfile_path + # Use utility function to write WAV and optionally play + return await play_pcm_with_ffplay(pcm, outfile_path=outfile_path, timeout_s=30.0) diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index 2f0c8f2e..aabb05db 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -25,7 +25,6 @@ tts_events_emitted, ) from ..edge.types import PcmData -import numpy as np logger = logging.getLogger(__name__) @@ -61,17 +60,14 @@ def __init__(self, provider_name: Optional[str] = None): self.provider_name = provider_name or self.__class__.__name__ self.events = EventManager() self.events.register_events_from_module(events, ignore_not_compatible=True) + # Desired output audio format (what downstream audio track expects) - # Agent can override via set_output_format self._desired_sample_rate: int = 16000 self._desired_channels: int = 1 self._desired_format: AudioFormat = AudioFormat.PCM_S16 - # Native/provider audio format default (used only if plugin returns raw bytes) - self._native_sample_rate: int = 16000 - self._native_channels: int = 1 - self._native_format: AudioFormat = AudioFormat.PCM_S16 + # Persistent resampler to avoid discontinuities between chunks - self._resampler = None + self._resampler: Optional[av.AudioResampler] = None self._resampler_input_rate: Optional[int] = None self._resampler_input_channels: Optional[int] = None @@ -113,7 +109,11 @@ def _get_resampler(self, input_rate: int, input_channels: int): PyAV AudioResampler instance """ - if self._resampler is not None and self._resampler_input_rate == input_rate and self._resampler_input_channels == input_channels: + if ( + self._resampler is not None + and self._resampler_input_rate == input_rate + and self._resampler_input_channels == input_channels + ): return self._resampler in_layout = "mono" if input_channels == 1 else "stereo" @@ -135,40 +135,21 @@ def _get_resampler(self, input_rate: int, input_channels: int): return self._resampler - def _normalize_to_pcm(self, item: Union[bytes, bytearray, PcmData, Any]) -> PcmData: - """Normalize a chunk to PcmData using the native provider format.""" - if isinstance(item, PcmData): - return item - data = getattr(item, "data", item) - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError("Chunk is not 
bytes or PcmData")
-        fmt = (
-            self._native_format.value
-            if hasattr(self._native_format, "value")
-            else "s16"
-        )
-        return PcmData.from_bytes(
-            bytes(data),
-            sample_rate=self._native_sample_rate,
-            channels=self._native_channels,
-            format=fmt,
-        )
-
     async def _iter_pcm(self, resp: Any) -> AsyncGenerator[PcmData, None]:
         """Yield PcmData chunks from a provider response of various shapes."""
         # Single buffer or PcmData
-        if isinstance(resp, (bytes, bytearray, PcmData)):
-            yield self._normalize_to_pcm(resp)
+        if isinstance(resp, (PcmData,)):
+            yield resp
             return
         # Async iterable
         if hasattr(resp, "__aiter__"):
             async for item in resp:
-                yield self._normalize_to_pcm(item)
+                yield item
             return
-        # Sync iterable (avoid treating bytes-like as iterable of ints)
-        if hasattr(resp, "__iter__") and not isinstance(resp, (str, bytes, bytearray)):
+        # Sync iterable
+        if hasattr(resp, "__iter__"):
             for item in resp:
-                yield self._normalize_to_pcm(item)
+                yield item
             return
         raise TypeError(f"Unsupported return type from stream_audio: {type(resp)}")
@@ -297,9 +278,9 @@ async def send(
         chunk_index = 0
 
         # Fast-path: single buffer -> mark final
-        if isinstance(response, (bytes, bytearray, PcmData)):
+        if isinstance(response, (PcmData,)):
             bytes_len, dur_ms = self._emit_chunk(
-                self._normalize_to_pcm(response), 0, True, synthesis_id, text, user
+                response, 0, True, synthesis_id, text, user
             )
             total_audio_bytes += bytes_len
             total_audio_ms += dur_ms

From 5261a5f68d8f85116c66f1566b8cf0746f928121 Mon Sep 17 00:00:00 2001
From: Tommaso Barbugli
Date: Fri, 24 Oct 2025 23:06:47 +0200
Subject: [PATCH 12/15] remove telemetry code that does not belong

---
 .../core/observability/__init__.py            |  2 -
 .../core/observability/metrics.py             | 92 ++++++++++---------
 2 files changed, 47 insertions(+), 47 deletions(-)

diff --git a/agents-core/vision_agents/core/observability/__init__.py b/agents-core/vision_agents/core/observability/__init__.py
index fe1420c0..f7f9edbf 100644
--- a/agents-core/vision_agents/core/observability/__init__.py
+++ b/agents-core/vision_agents/core/observability/__init__.py
@@ -17,7 +17,6 @@
     tts_errors,
     tts_events_emitted,
     inflight_ops,
-    CALL_ATTRS,
 )
 
 __all__ = [
@@ -33,5 +32,4 @@
     "tts_errors",
     "tts_events_emitted",
     "inflight_ops",
-    "CALL_ATTRS",
 ]
diff --git a/agents-core/vision_agents/core/observability/metrics.py b/agents-core/vision_agents/core/observability/metrics.py
index 066e215a..f5575e8d 100644
--- a/agents-core/vision_agents/core/observability/metrics.py
+++ b/agents-core/vision_agents/core/observability/metrics.py
@@ -1,41 +1,57 @@
-# otel_setup.py
-from opentelemetry import trace, metrics
-from opentelemetry.sdk.resources import Resource
-from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.metrics import MeterProvider
-from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
-from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
-from opentelemetry.sdk.trace.export import BatchSpanProcessor
-from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+"""OpenTelemetry observability instrumentation for vision-agents library.
+
+This module defines metrics and tracers for the vision-agents library. It does NOT
+configure OpenTelemetry providers - that is the responsibility of applications using
+this library. 
+ +For applications using this library: + To enable telemetry, configure OpenTelemetry in your application before importing + vision-agents components: -# Point these at your collector (default shown) -OTLP_ENDPOINT = "http://localhost:4317" + ```python + from opentelemetry import trace, metrics + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource -resource = Resource.create( - { - "service.name": "voice-agent", + # Configure your service + resource = Resource.create({ + "service.name": "my-voice-app", "service.version": "1.0.0", - } -) + }) -# --- Traces --- -tracer_provider = TracerProvider(resource=resource) -tracer_provider.add_span_processor( - BatchSpanProcessor(OTLPSpanExporter(endpoint=OTLP_ENDPOINT)) -) -trace.set_tracer_provider(tracer_provider) -tracer = trace.get_tracer(__name__) + # Setup trace provider + trace_provider = TracerProvider(resource=resource) + trace_provider.add_span_processor( + BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317")) + ) + trace.set_tracer_provider(trace_provider) -# --- Metrics --- -metric_reader = PeriodicExportingMetricReader( - OTLPMetricExporter(endpoint=OTLP_ENDPOINT) -) -meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) -metrics.set_meter_provider(meter_provider) -meter = metrics.get_meter(__name__) + # Setup metrics provider + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter(endpoint="http://localhost:4317") + ) + metrics_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(metrics_provider) + # Now import and use vision-agents + from vision_agents.core.tts import TTS + ``` -meter = metrics.get_meter("voice-agent.latency") + If no providers are configured, metrics and traces will be no-ops. +""" + +from opentelemetry import trace, metrics + +# Get tracer and meter using the library name +# These will use whatever providers the application has configured +# If no providers are configured, they will be no-ops +tracer = trace.get_tracer("vision_agents.core") +meter = metrics.get_meter("vision_agents.core") stt_latency_ms = meter.create_histogram( "stt.latency.ms", unit="ms", description="Total STT latency" @@ -65,17 +81,3 @@ inflight_ops = meter.create_up_down_counter( "voice.ops.inflight", description="Inflight voice ops" ) - -CALL_ATTRS = { - "provider": "deepgram", # or "whisper", "revai", "gcloud", etc. 
- "model": "nova-2", # your model id - "lang": "en-US", # BCP-47 / ISO code - "transport": "http", # or "websocket", "grpc" - "streaming": True, # True/False -} - -with tracer.start_as_current_span("stt.request", kind=trace.SpanKind.CLIENT) as span: - pass - -span = tracer.start_span("stt.request") -span.end() From 1707aa49acffcf1893dfaa2ac599a3e7af7c948c Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 23:13:17 +0200 Subject: [PATCH 13/15] remove debug test --- tests/test_resample_quality.py | 146 --------------------------------- 1 file changed, 146 deletions(-) delete mode 100644 tests/test_resample_quality.py diff --git a/tests/test_resample_quality.py b/tests/test_resample_quality.py deleted file mode 100644 index e19c7c50..00000000 --- a/tests/test_resample_quality.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Test to investigate resampling quality issues.""" - -import numpy as np -from vision_agents.core.edge.types import PcmData -import av - - -def test_compare_resampling_methods(): - """Compare PyAV resampling with scipy for quality.""" - # Create 1 second of clean sine wave at 16kHz - sample_rate_in = 16000 - sample_rate_out = 48000 - duration = 1.0 - freq = 440 # A4 note - - num_samples_in = int(sample_rate_in * duration) - t = np.linspace(0, duration, num_samples_in, dtype=np.float64) - - # Generate clean sine wave - sine_wave = np.sin(2 * np.pi * freq * t) - audio_int16 = (sine_wave * 32767).astype(np.int16) - - # Method 1: PyAV (current implementation) - pcm_16k = PcmData( - samples=audio_int16, sample_rate=sample_rate_in, format="s16", channels=1 - ) - pcm_48k_pyav = pcm_16k.resample(sample_rate_out, target_channels=1) - - print("\n=== PyAV Resampler ===") - print(f"Input: {len(audio_int16)} samples @ {sample_rate_in}Hz") - print(f"Output: {len(pcm_48k_pyav.samples)} samples @ {sample_rate_out}Hz") - print(f"Output dtype: {pcm_48k_pyav.samples.dtype}") - print(f"Output shape: {pcm_48k_pyav.samples.shape}") - - # Check for clipping or artifacts - pyav_samples = ( - pcm_48k_pyav.samples.flatten() - if pcm_48k_pyav.samples.ndim > 1 - else pcm_48k_pyav.samples - ) - print(f"Output min: {pyav_samples.min()}, max: {pyav_samples.max()}") - - # Check for discontinuities (potential clicks) - diffs = np.abs(np.diff(pyav_samples.astype(np.float32))) - max_jump = np.max(diffs) - mean_jump = np.mean(diffs) - print(f"Max sample-to-sample jump: {max_jump:.1f}") - print(f"Mean sample-to-sample jump: {mean_jump:.1f}") - - # Large jumps indicate clicks - large_jumps = np.sum(diffs > 10000) - print(f"Number of large jumps (>10000): {large_jumps}") - - # Method 2: Try scipy for comparison - try: - from scipy import signal - - # Resample using scipy's high-quality resampler - num_samples_out = int(len(audio_int16) * sample_rate_out / sample_rate_in) - audio_float = audio_int16.astype(np.float32) / 32768.0 - resampled_scipy = signal.resample(audio_float, num_samples_out) - resampled_scipy_int16 = (np.clip(resampled_scipy, -1.0, 1.0) * 32767).astype( - np.int16 - ) - - print("\n=== SciPy Resampler ===") - print(f"Output: {len(resampled_scipy_int16)} samples @ {sample_rate_out}Hz") - print( - f"Output min: {resampled_scipy_int16.min()}, max: {resampled_scipy_int16.max()}" - ) - - diffs_scipy = np.abs(np.diff(resampled_scipy_int16.astype(np.float32))) - max_jump_scipy = np.max(diffs_scipy) - mean_jump_scipy = np.mean(diffs_scipy) - print(f"Max sample-to-sample jump: {max_jump_scipy:.1f}") - print(f"Mean sample-to-sample jump: {mean_jump_scipy:.1f}") - - large_jumps_scipy = 
np.sum(diffs_scipy > 10000) - print(f"Number of large jumps (>10000): {large_jumps_scipy}") - - # Save both for manual inspection - import wave - import io - - def save_wav(samples, sr, filename): - buf = io.BytesIO() - with wave.open(buf, "wb") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sr) - wf.writeframes(samples.tobytes()) - with open(filename, "wb") as f: - f.write(buf.getvalue()) - - save_wav(audio_int16, sample_rate_in, "/tmp/original_16k.wav") - save_wav(pyav_samples.astype(np.int16), sample_rate_out, "/tmp/pyav_48k.wav") - save_wav(resampled_scipy_int16, sample_rate_out, "/tmp/scipy_48k.wav") - - print("\nWAV files saved to /tmp/ for comparison") - - except ImportError: - print("\nSciPy not available for comparison") - - -def test_pyav_resampler_settings(): - """Check if PyAV resampler has quality settings we're missing.""" - sample_rate_in = 16000 - sample_rate_out = 48000 - num_samples = 16000 - - # Create test signal - t = np.linspace(0, 1.0, num_samples, dtype=np.float64) - audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) - - # Create frame - frame = av.AudioFrame.from_ndarray( - audio.reshape(1, -1), format="s16p", layout="mono" - ) - frame.sample_rate = sample_rate_in - - # Try different resampler configurations - print("\n=== Testing PyAV Resampler Options ===") - - # Default resampler - resampler_default = av.AudioResampler( - format="s16", layout="mono", rate=sample_rate_out - ) - - print("Default resampler created") - print(f"Resampler: {resampler_default}") - - # Check if there are any quality options available - # Note: PyAV/FFmpeg's libswresample has quality options but might not be exposed - - frames = resampler_default.resample(frame) - if frames: - result = frames[0].to_ndarray().flatten() - print(f"Default output: {len(result)} samples") - - diffs = np.abs(np.diff(result.astype(np.float32))) - print(f"Max jump: {np.max(diffs):.1f}, Mean jump: {np.mean(diffs):.1f}") - - -if __name__ == "__main__": - test_compare_resampling_methods() - test_pyav_resampler_settings() From 1b667867b0bc572773bbb2d9cf971bbbdb368fba Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 23:22:03 +0200 Subject: [PATCH 14/15] better code --- agents-core/vision_agents/core/tts/manual_test.py | 7 +++---- agents-core/vision_agents/core/tts/tts.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/agents-core/vision_agents/core/tts/manual_test.py b/agents-core/vision_agents/core/tts/manual_test.py index 836ab05b..b6d2c727 100644 --- a/agents-core/vision_agents/core/tts/manual_test.py +++ b/agents-core/vision_agents/core/tts/manual_test.py @@ -20,11 +20,11 @@ async def manual_tts_to_wav( ) -> str: """Generate TTS audio to a WAV file and optionally play with ffplay. - - Creates the TTS instance via `tts_factory()`. - - Sets desired output format via `set_output_format(sample_rate, channels)`. + - Receives a TTS instance. + - Configures desired output format via `set_output_format(sample_rate, channels)`. - Sends `text` and captures TTSAudioEvent chunks. - Writes a WAV (s16) file and returns the path. - - If env `play_env` is set to "1" and `ffplay` exists, it plays the file. + - If `ffplay` exists, it plays the file. Args: tts: the TTS instance. @@ -33,7 +33,6 @@ async def manual_tts_to_wav( text: text to synthesize. outfile_path: optional absolute path for the WAV file; if None, temp path. timeout_s: timeout for first audio to arrive. - play_env: env var name controlling playback (default: FFPLAY). 
Returns: Path to written WAV file. diff --git a/agents-core/vision_agents/core/tts/tts.py b/agents-core/vision_agents/core/tts/tts.py index aabb05db..fc984c7c 100644 --- a/agents-core/vision_agents/core/tts/tts.py +++ b/agents-core/vision_agents/core/tts/tts.py @@ -144,11 +144,21 @@ async def _iter_pcm(self, resp: Any) -> AsyncGenerator[PcmData, None]: # Async iterable if hasattr(resp, "__aiter__"): async for item in resp: + if not isinstance(item, PcmData): + raise TypeError( + "stream_audio must yield PcmData; wrap provider bytes via PcmData.from_response in the plugin" + ) yield item return # Sync iterable - if hasattr(resp, "__iter__"): + if hasattr(resp, "__iter__") and not isinstance( + resp, (bytes, bytearray, memoryview, str) + ): for item in resp: + if not isinstance(item, PcmData): + raise TypeError( + "stream_audio must yield PcmData; wrap provider bytes via PcmData.from_response in the plugin" + ) yield item return raise TypeError(f"Unsupported return type from stream_audio: {type(resp)}") From 2c0022881870d0e1491db8833a1ee60f3bd0c7a9 Mon Sep 17 00:00:00 2001 From: Tommaso Barbugli Date: Fri, 24 Oct 2025 23:38:58 +0200 Subject: [PATCH 15/15] fix tests --- agents-core/vision_agents/core/edge/types.py | 8 + tests/test_utils.py | 274 +++++++++++-------- 2 files changed, 168 insertions(+), 114 deletions(-) diff --git a/agents-core/vision_agents/core/edge/types.py b/agents-core/vision_agents/core/edge/types.py index 95ba2d6e..89f64e0e 100644 --- a/agents-core/vision_agents/core/edge/types.py +++ b/agents-core/vision_agents/core/edge/types.py @@ -418,6 +418,14 @@ def resample( ) resampled_samples = raw_array.reshape(ch, -1) + # Flatten mono arrays to 1D for consistency + if ( + ch == 1 + and isinstance(resampled_samples, np.ndarray) + and resampled_samples.ndim > 1 + ): + resampled_samples = resampled_samples.flatten() + # Ensure int16 dtype for s16 if ( isinstance(resampled_samples, np.ndarray) diff --git a/tests/test_utils.py b/tests/test_utils.py index 89c3bfa5..c1c4068b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,89 +7,97 @@ class TestParseInstructions: """Test suite for the parse_instructions function.""" - + def test_parse_instructions_no_mentions(self): """Test parsing text with no @ mentions.""" text = "This is a simple instruction without any mentions." result = parse_instructions(text) - + assert isinstance(result, Instructions) assert result.input_text == text assert result.markdown_contents == {} - + def test_parse_instructions_single_mention(self): """Test parsing text with a single @ mention.""" text = "Please read @nonexistent.md for more information." result = parse_instructions(text) - + assert result.input_text == text assert result.markdown_contents == {"nonexistent.md": ""} # File doesn't exist - + def test_parse_instructions_multiple_mentions(self): """Test parsing text with multiple @ mentions.""" text = "Check @file1.md and @file2.md for details. Also see @guide.md." result = parse_instructions(text) - + assert result.input_text == text - assert result.markdown_contents == {"file1.md": "", "file2.md": "", "guide.md": ""} - + assert result.markdown_contents == { + "file1.md": "", + "file2.md": "", + "guide.md": "", + } + def test_parse_instructions_duplicate_mentions(self): """Test parsing text with duplicate @ mentions.""" text = "Read @nonexistent.md and then @nonexistent.md again." 
result = parse_instructions(text) - + assert result.input_text == text # Should only include unique filenames assert result.markdown_contents == {"nonexistent.md": ""} - + def test_parse_instructions_non_markdown_mentions(self): """Test parsing text with @ mentions that are not markdown files.""" text = "Check @user123 and @file.txt for information." result = parse_instructions(text) - + assert result.input_text == text # Should only capture .md files assert result.markdown_contents == {} - + def test_parse_instructions_mixed_mentions(self): """Test parsing text with both markdown and non-markdown @ mentions.""" text = "Check @user123, @nonexistent.md, and @config.txt for details." result = parse_instructions(text) - + assert result.input_text == text # Should only capture .md files assert result.markdown_contents == {"nonexistent.md": ""} - + def test_parse_instructions_complex_filenames(self): """Test parsing text with complex markdown filenames.""" text = "See @my-file.md, @file_with_underscores.md, and @file-with-dashes.md." result = parse_instructions(text) - + assert result.input_text == text - assert result.markdown_contents == {"my-file.md": "", "file_with_underscores.md": "", "file-with-dashes.md": ""} - + assert result.markdown_contents == { + "my-file.md": "", + "file_with_underscores.md": "", + "file-with-dashes.md": "", + } + def test_parse_instructions_edge_cases(self): """Test parsing text with edge cases.""" # Empty string result = parse_instructions("") assert result.input_text == "" assert result.markdown_contents == {} - + # Only @ symbol result = parse_instructions("@") assert result.input_text == "@" assert result.markdown_contents == {} - + # @ without filename result = parse_instructions("Check @ for details") assert result.input_text == "Check @ for details" assert result.markdown_contents == {} - + # @ with spaces in filename (should not match) result = parse_instructions("Check @my file.md for details") assert result.input_text == "Check @my file.md for details" assert result.markdown_contents == {} - + def test_parse_instructions_case_sensitivity(self): """Test that @ mentions with different cases are extracted separately.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -97,36 +105,40 @@ def test_parse_instructions_case_sensitivity(self): # because macOS and Windows use case-insensitive filesystems by default file1_path = os.path.join(temp_dir, "Guide.md") file2_path = os.path.join(temp_dir, "Help.md") - - with open(file1_path, 'w', encoding='utf-8') as f: + + with open(file1_path, "w", encoding="utf-8") as f: f.write("# Guide Content") - - with open(file2_path, 'w', encoding='utf-8') as f: + + with open(file2_path, "w", encoding="utf-8") as f: f.write("# Help Content") - + # Test that the parser correctly extracts both case variations from text # even if they refer to the same file on case-insensitive filesystems text = "Check @Guide.md and @guide.md and @Help.md for information." 
result = parse_instructions(text, base_dir=temp_dir) - + assert result.input_text == text # Parser should extract all mentioned filenames assert "Guide.md" in result.markdown_contents - assert "guide.md" in result.markdown_contents + assert "guide.md" in result.markdown_contents assert "Help.md" in result.markdown_contents # On case-insensitive systems, Guide.md and guide.md will have same content # but the parser still tracks them separately by their @ mention assert len(result.markdown_contents["Guide.md"]) > 0 assert len(result.markdown_contents["Help.md"]) > 0 - + def test_parse_instructions_special_characters(self): """Test parsing with special characters in filenames.""" text = "Check @file-1.md, @file_2.md, and @file.3.md for details." result = parse_instructions(text) - + assert result.input_text == text - assert result.markdown_contents == {"file-1.md": "", "file_2.md": "", "file.3.md": ""} - + assert result.markdown_contents == { + "file-1.md": "", + "file_2.md": "", + "file.3.md": "", + } + def test_parse_instructions_multiline_text(self): """Test parsing multiline text with @ mentions.""" text = """Please review the following files: @@ -135,82 +147,96 @@ def test_parse_instructions_multiline_text(self): - @troubleshooting.md for common issues """ result = parse_instructions(text) - + assert result.input_text == text - assert result.markdown_contents == {"setup.md": "", "api.md": "", "troubleshooting.md": ""} + assert result.markdown_contents == { + "setup.md": "", + "api.md": "", + "troubleshooting.md": "", + } class TestInstructions: """Test suite for the Instructions dataclass.""" - + def test_instructions_initialization(self): """Test Instructions dataclass initialization.""" input_text = "Test instruction" markdown_contents = {"file1.md": "# File 1 content"} - + instructions = Instructions(input_text, markdown_contents) - + assert instructions.input_text == input_text assert instructions.markdown_contents == markdown_contents - + def test_instructions_empty_markdown_files(self): """Test Instructions with empty markdown files dict.""" input_text = "Simple instruction" markdown_contents = {} - + instructions = Instructions(input_text, markdown_contents) - + assert instructions.input_text == input_text assert instructions.markdown_contents == {} - + def test_instructions_equality(self): """Test Instructions equality comparison.""" instructions1 = Instructions("test", {"file.md": "content"}) instructions2 = Instructions("test", {"file.md": "content"}) instructions3 = Instructions("different", {"file.md": "content"}) - + assert instructions1 == instructions2 assert instructions1 != instructions3 class TestParseInstructionsFileReading: """Test suite for file reading functionality in parse_instructions.""" - + def test_parse_instructions_with_existing_files(self): """Test parsing with actual markdown files that exist.""" with tempfile.TemporaryDirectory() as temp_dir: # Create test markdown files file1_path = os.path.join(temp_dir, "readme.md") file2_path = os.path.join(temp_dir, "guide.md") - - with open(file1_path, 'w', encoding='utf-8') as f: + + with open(file1_path, "w", encoding="utf-8") as f: f.write("# README\n\nThis is a test readme file.") - - with open(file2_path, 'w', encoding='utf-8') as f: + + with open(file2_path, "w", encoding="utf-8") as f: f.write("# Guide\n\nThis is a test guide file.") - + text = "Please read @readme.md and @guide.md for information." 
result = parse_instructions(text, base_dir=temp_dir) - + assert result.input_text == text - assert result.markdown_contents["readme.md"] == "# README\n\nThis is a test readme file." - assert result.markdown_contents["guide.md"] == "# Guide\n\nThis is a test guide file." - + assert ( + result.markdown_contents["readme.md"] + == "# README\n\nThis is a test readme file." + ) + assert ( + result.markdown_contents["guide.md"] + == "# Guide\n\nThis is a test guide file." + ) + def test_parse_instructions_with_mixed_existing_nonexisting_files(self): """Test parsing with mix of existing and non-existing files.""" with tempfile.TemporaryDirectory() as temp_dir: # Create only one test file file1_path = os.path.join(temp_dir, "readme.md") - with open(file1_path, 'w', encoding='utf-8') as f: + with open(file1_path, "w", encoding="utf-8") as f: f.write("# README\n\nThis file exists.") - + text = "Check @readme.md and @nonexistent.md for details." result = parse_instructions(text, base_dir=temp_dir) - + assert result.input_text == text - assert result.markdown_contents["readme.md"] == "# README\n\nThis file exists." - assert result.markdown_contents["nonexistent.md"] == "" # Empty for non-existing file - + assert ( + result.markdown_contents["readme.md"] == "# README\n\nThis file exists." + ) + assert ( + result.markdown_contents["nonexistent.md"] == "" + ) # Empty for non-existing file + def test_parse_instructions_with_custom_base_dir(self): """Test parsing with custom base directory.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -218,70 +244,77 @@ def test_parse_instructions_with_custom_base_dir(self): subdir = os.path.join(temp_dir, "docs") os.makedirs(subdir) file_path = os.path.join(subdir, "api.md") - - with open(file_path, 'w', encoding='utf-8') as f: + + with open(file_path, "w", encoding="utf-8") as f: f.write("# API Documentation\n\nThis is the API docs.") - + text = "See @api.md for API information." result = parse_instructions(text, base_dir=subdir) - + assert result.input_text == text - assert result.markdown_contents["api.md"] == "# API Documentation\n\nThis is the API docs." - + assert ( + result.markdown_contents["api.md"] + == "# API Documentation\n\nThis is the API docs." + ) + def test_parse_instructions_file_read_error_handling(self): """Test handling of file read errors.""" with tempfile.TemporaryDirectory() as temp_dir: # Create a file that will cause read errors (permission issues, etc.) file_path = os.path.join(temp_dir, "readme.md") - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.write("test content") - + # Make file unreadable (this might not work on all systems) try: os.chmod(file_path, 0o000) # No permissions - + text = "Read @readme.md for information." result = parse_instructions(text, base_dir=temp_dir) - + assert result.input_text == text - assert result.markdown_contents["readme.md"] == "" # Empty due to read error + assert ( + result.markdown_contents["readme.md"] == "" + ) # Empty due to read error finally: # Restore permissions for cleanup os.chmod(file_path, 0o644) - + def test_parse_instructions_unicode_content(self): """Test parsing with unicode content in markdown files.""" with tempfile.TemporaryDirectory() as temp_dir: file_path = os.path.join(temp_dir, "unicode.md") - + # Write unicode content - unicode_content = "# Unicode Test\n\nHello 世界! 🌍\n\nThis has émojis and àccénts." - with open(file_path, 'w', encoding='utf-8') as f: + unicode_content = ( + "# Unicode Test\n\nHello 世界! 
🌍\n\nThis has émojis and àccénts." + ) + with open(file_path, "w", encoding="utf-8") as f: f.write(unicode_content) - + text = "Check @unicode.md for unicode content." result = parse_instructions(text, base_dir=temp_dir) - + assert result.input_text == text assert result.markdown_contents["unicode.md"] == unicode_content - + def test_parse_instructions_default_base_dir(self): """Test that default base directory is current working directory.""" with tempfile.TemporaryDirectory() as temp_dir: # Create a test file file_path = os.path.join(temp_dir, "readme.md") - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.write("# Test readme content") - + # Change to temp directory to test default base_dir original_cwd = os.getcwd() try: os.chdir(temp_dir) - + # This test verifies that when no base_dir is provided, it uses os.getcwd() text = "Read @readme.md for information." result = parse_instructions(text) # No base_dir provided - + assert result.input_text == text # Content will not be empty since readme.md exists in current directory assert "readme.md" in result.markdown_contents @@ -294,91 +327,103 @@ def test_parse_instructions_default_base_dir(self): class TestPcmDataMethods: """Test suite for PcmData class methods.""" - + def test_pcm_data_from_bytes(self): """Test PcmData.from_bytes class method.""" # Create test audio data (1 second of 16kHz audio) test_samples = np.random.randint(-32768, 32767, 16000, dtype=np.int16) audio_bytes = test_samples.tobytes() - + pcm_data = PcmData.from_bytes(audio_bytes, sample_rate=16000, format="s16") - + assert pcm_data.sample_rate == 16000 assert pcm_data.format == "s16" assert np.array_equal(pcm_data.samples, test_samples) assert pcm_data.duration == 1.0 # 1 second - + def test_pcm_data_resample_same_rate(self): """Test resampling when source and target rates are the same.""" test_samples = np.random.randint(-32768, 32767, 16000, dtype=np.int16) pcm_data = PcmData(samples=test_samples, sample_rate=16000, format="s16") - + resampled = pcm_data.resample(target_sample_rate=16000) - + # Should return the same data assert resampled.sample_rate == 16000 assert np.array_equal(resampled.samples, test_samples) assert resampled.format == "s16" - + def test_pcm_data_resample_24khz_to_48khz(self): """Test resampling from 24kHz to 48kHz (Gemini use case).""" # Create test audio data (1 second of 24kHz audio) test_samples = np.random.randint(-32768, 32767, 24000, dtype=np.int16) pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16") - + resampled = pcm_data.resample(target_sample_rate=48000) - + assert resampled.sample_rate == 48000 assert resampled.format == "s16" # Should have approximately double the samples (24k -> 48k) - assert abs(len(resampled.samples) - 48000) < 100 # Allow some tolerance + # Handle both 1D and 2D arrays + num_samples = ( + resampled.samples.shape[-1] + if resampled.samples.ndim > 1 + else len(resampled.samples) + ) + assert abs(num_samples - 48000) < 100 # Allow some tolerance # Duration should be approximately the same assert abs(resampled.duration - 1.0) < 0.1 - + def test_pcm_data_resample_48khz_to_16khz(self): """Test resampling from 48kHz to 16kHz.""" # Create test audio data (1 second of 48kHz audio) test_samples = np.random.randint(-32768, 32767, 48000, dtype=np.int16) pcm_data = PcmData(samples=test_samples, sample_rate=48000, format="s16") - + resampled = pcm_data.resample(target_sample_rate=16000) - + assert resampled.sample_rate == 16000 assert resampled.format 
== "s16" # Should have approximately 1/3 the samples (48k -> 16k) - assert abs(len(resampled.samples) - 16000) < 100 # Allow some tolerance + # Handle both 1D and 2D arrays + num_samples = ( + resampled.samples.shape[-1] + if resampled.samples.ndim > 1 + else len(resampled.samples) + ) + assert abs(num_samples - 16000) < 100 # Allow some tolerance # Duration should be approximately the same assert abs(resampled.duration - 1.0) < 0.1 - + def test_pcm_data_resample_preserves_metadata(self): """Test that resampling preserves PTS, DTS, and time_base metadata.""" test_samples = np.random.randint(-32768, 32767, 16000, dtype=np.int16) pcm_data = PcmData( - samples=test_samples, - sample_rate=16000, + samples=test_samples, + sample_rate=16000, format="s16", pts=1000, dts=950, - time_base=0.001 + time_base=0.001, ) - + resampled = pcm_data.resample(target_sample_rate=48000) - + assert resampled.pts == 1000 assert resampled.dts == 950 assert resampled.time_base == 0.001 assert abs(resampled.pts_seconds - 1.0) < 0.0001 assert abs(resampled.dts_seconds - 0.95) < 0.0001 - + def test_pcm_data_resample_handles_1d_array(self): """Test that resampling handles 1D arrays correctly (fixes ndim error).""" # Create test audio data (1 second of 24kHz audio) - 1D array test_samples = np.random.randint(-32768, 32767, 24000, dtype=np.int16) pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16") - + # This should now work without the ndim error resampled = pcm_data.resample(target_sample_rate=48000) - + assert resampled.sample_rate == 48000 assert resampled.format == "s16" # Should have approximately double the samples (24k -> 48k) @@ -387,16 +432,16 @@ def test_pcm_data_resample_handles_1d_array(self): assert abs(resampled.duration - 1.0) < 0.1 # Output should be 1D array assert resampled.samples.ndim == 1 - + def test_pcm_data_resample_handles_2d_array(self): """Test that resampling handles 2D arrays correctly.""" # Create test audio data (1 second of 24kHz audio) - 2D array (channels, samples) test_samples = np.random.randint(-32768, 32767, (1, 24000), dtype=np.int16) pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16") - + # This should work with 2D arrays too resampled = pcm_data.resample(target_sample_rate=48000) - + assert resampled.sample_rate == 48000 assert resampled.format == "s16" # Should have approximately double the samples (24k -> 48k) @@ -405,17 +450,17 @@ def test_pcm_data_resample_handles_2d_array(self): assert abs(resampled.duration - 1.0) < 0.1 # Output should be 1D array (flattened) assert resampled.samples.ndim == 1 - + def test_pcm_data_from_bytes_and_resample_chain(self): """Test chaining from_bytes and resample methods (Gemini use case).""" # Create test audio data (1 second of 24kHz audio) test_samples = np.random.randint(-32768, 32767, 24000, dtype=np.int16) audio_bytes = test_samples.tobytes() - + # Chain the methods like in realtime2.py pcm_data = PcmData.from_bytes(audio_bytes, sample_rate=24000, format="s16") resampled_pcm = pcm_data.resample(target_sample_rate=48000) - + assert pcm_data.sample_rate == 24000 assert resampled_pcm.sample_rate == 48000 assert resampled_pcm.format == "s16" @@ -423,16 +468,18 @@ def test_pcm_data_from_bytes_and_resample_chain(self): assert abs(len(resampled_pcm.samples) - 48000) < 100 # Allow some tolerance # Duration should be approximately the same assert abs(resampled_pcm.duration - 1.0) < 0.1 - + def test_pcm_data_resample_av_array_shape_fix(self): """Test that fixes the AV library array shape error (channels, 
samples).""" # Create test audio data that would cause the "Expected packed array.shape[0] to equal 1" error - test_samples = np.random.randint(-32768, 32767, 1920, dtype=np.int16) # Small chunk like in the error + test_samples = np.random.randint( + -32768, 32767, 1920, dtype=np.int16 + ) # Small chunk like in the error pcm_data = PcmData(samples=test_samples, sample_rate=24000, format="s16") - + # This should work without the array shape error resampled = pcm_data.resample(target_sample_rate=48000) - + assert resampled.sample_rate == 48000 assert resampled.format == "s16" # Should have approximately double the samples (1920 -> ~3840) @@ -442,4 +489,3 @@ def test_pcm_data_resample_av_array_shape_fix(self): # Shared fixtures for integration tests -
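The shape handling above gives `PcmData.resample` a simple observable contract that the updated tests rely on: mono output always comes back as a flattened 1-D int16 array, and the sample count and duration scale with the target rate. A minimal sketch of that contract (zeros stand in for real audio; the tolerances mirror the tests above):

```python
import numpy as np
from vision_agents.core.edge.types import PcmData

# 1 second of 24kHz mono s16 audio, as in the resample tests
samples = np.zeros(24_000, dtype=np.int16)
pcm = PcmData(samples=samples, sample_rate=24_000, format="s16", channels=1)

out = pcm.resample(target_sample_rate=48_000)
assert out.sample_rate == 48_000
assert out.samples.ndim == 1                 # mono output is flattened to 1-D
assert abs(len(out.samples) - 48_000) < 100  # roughly twice the input samples
assert abs(out.duration - 1.0) < 0.1         # duration is preserved
```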