Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions agents-core/vision_agents/core/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import uuid4

import aiortc
import getstream.models
from aiortc import VideoStreamTrack
from getstream.video.rtc import Call
Expand All @@ -15,7 +14,7 @@
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
from ..edge import sfu_events
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
from ..edge.types import Connection, Participant, PcmData, User
from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack
from ..events.manager import EventManager
from ..llm import events as llm_events
from ..llm.events import (
Expand All @@ -32,6 +31,7 @@
from ..stt.events import STTTranscriptEvent, STTErrorEvent
from ..stt.stt import STT
from ..tts.tts import TTS
from ..tts.events import TTSAudioEvent
from ..turn_detection import TurnDetector, TurnStartedEvent, TurnEndedEvent
from ..vad import VAD
from ..vad.events import VADAudioEvent
Expand Down Expand Up @@ -160,7 +160,7 @@ def __init__(
self._callback_executed = False
self._track_tasks: Dict[str, asyncio.Task] = {}
self._connection: Optional[Connection] = None
self._audio_track: Optional[aiortc.AudioStreamTrack] = None
self._audio_track: Optional[OutputAudioTrack] = None
self._video_track: Optional[VideoStreamTrack] = None
self._realtime_connection = None
self._pc_track_handler_attached: bool = False
Expand Down Expand Up @@ -307,6 +307,11 @@ async def on_realtime_agent_speech_transcription(
original=event,
)

@self.events.subscribe
async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
if self._audio_track and event and event.audio_data is not None:
await self._audio_track.write(event.audio_data)

@self.events.subscribe
async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
if self.realtime_mode or not self.llm:
Expand Down Expand Up @@ -1021,19 +1026,19 @@ def _prepare_rtc(self):
self._audio_track = self.llm.output_track
self.logger.info("🎵 Using Realtime provider output track for audio")
else:
# TODO: what if we want to transform audio...
# Get the required framerate and stereo setting from TTS plugin, default to 48000 for WebRTC
if self.tts:
framerate = self.tts.get_required_framerate()
stereo = self.tts.get_required_stereo()
else:
framerate = 48000
stereo = True # Default to stereo for WebRTC
# Default to WebRTC-friendly format unless configured differently
framerate = 48000
stereo = True
self._audio_track = self.edge.create_audio_track(
framerate=framerate, stereo=stereo
)
# Inform TTS of desired output format so it can resample accordingly
if self.tts:
self.tts.set_output_track(self._audio_track)
channels = 2 if stereo else 1
self.tts.set_output_format(
sample_rate=framerate,
channels=channels,
)

# Set up video track if video publishers are available
if self.publish_video:
Expand Down
11 changes: 6 additions & 5 deletions agents-core/vision_agents/core/edge/edge_transport.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""
Abstraction for stream vs other services here
"""

import abc

from typing import TYPE_CHECKING, Any, Optional

import aiortc
from pyee.asyncio import AsyncIOEventEmitter

from vision_agents.core.edge.types import User
from vision_agents.core.edge.types import User, OutputAudioTrack

if TYPE_CHECKING:

pass


Expand All @@ -31,7 +31,7 @@ async def create_user(self, user: User):
pass

@abc.abstractmethod
def create_audio_track(self):
def create_audio_track(self) -> OutputAudioTrack:
pass

@abc.abstractmethod
Expand All @@ -55,6 +55,7 @@ async def create_conversation(self, call: Any, user: User, instructions):
pass

@abc.abstractmethod
def add_track_subscriber(self, track_id: str) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
def add_track_subscriber(
self, track_id: str
) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
pass

Loading