Skip to content

Commit b7c57e9

Browse files
committed
properly type the output track
1 parent 9e9366a commit b7c57e9

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

agents-core/vision_agents/core/agents/agents.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
import logging
33
import time
44
import uuid
5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
5+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
66
from uuid import uuid4
77

8-
import aiortc
98
import getstream.models
109
from aiortc import VideoStreamTrack
1110
from getstream.video.rtc import Call
@@ -15,7 +14,7 @@
1514
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
1615
from ..edge import sfu_events
1716
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
18-
from ..edge.types import Connection, Participant, PcmData, User
17+
from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack
1918
from ..events.manager import EventManager
2019
from ..llm import events as llm_events
2120
from ..llm.events import (
@@ -161,7 +160,7 @@ def __init__(
161160
self._callback_executed = False
162161
self._track_tasks: Dict[str, asyncio.Task] = {}
163162
self._connection: Optional[Connection] = None
164-
self._audio_track: Optional[aiortc.AudioStreamTrack] = None
163+
self._audio_track: Optional[OutputAudioTrack] = None
165164
self._video_track: Optional[VideoStreamTrack] = None
166165
self._realtime_connection = None
167166
self._pc_track_handler_attached: bool = False
@@ -308,15 +307,10 @@ async def on_realtime_agent_speech_transcription(
308307
original=event,
309308
)
310309

311-
# Listen for TTS audio events and write audio to the output track
312310
@self.events.subscribe
313-
async def _on_tts_audio(event: TTSAudioEvent):
314-
try:
315-
if self._audio_track and event.audio_data:
316-
track_any = cast(Any, self._audio_track)
317-
await track_any.write(event.audio_data)
318-
except Exception as e:
319-
self.logger.error(f"Error writing TTS audio to track: {e}")
311+
async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
312+
if self._audio_track and event and event.audio_data is not None:
313+
await self._audio_track.write(event.audio_data)
320314

321315
@self.events.subscribe
322316
async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):

agents-core/vision_agents/core/edge/edge_transport.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
"""
22
Abstraction for stream vs other services here
33
"""
4+
45
import abc
56

67
from typing import TYPE_CHECKING, Any, Optional
78

89
import aiortc
910
from pyee.asyncio import AsyncIOEventEmitter
1011

11-
from vision_agents.core.edge.types import User
12+
from vision_agents.core.edge.types import User, OutputAudioTrack
1213

1314
if TYPE_CHECKING:
14-
1515
pass
1616

1717

@@ -31,7 +31,7 @@ async def create_user(self, user: User):
3131
pass
3232

3333
@abc.abstractmethod
34-
def create_audio_track(self):
34+
def create_audio_track(self) -> OutputAudioTrack:
3535
pass
3636

3737
@abc.abstractmethod
@@ -55,6 +55,7 @@ async def create_conversation(self, call: Any, user: User, instructions):
5555
pass
5656

5757
@abc.abstractmethod
58-
def add_track_subscriber(self, track_id: str) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
58+
def add_track_subscriber(
59+
self, track_id: str
60+
) -> Optional[aiortc.mediastreams.MediaStreamTrack]:
5961
pass
60-

agents-core/vision_agents/core/edge/types.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from dataclasses import dataclass
2-
from typing import Any, Optional, NamedTuple, Union, Iterator, AsyncIterator
2+
from typing import (
3+
Any,
4+
Optional,
5+
NamedTuple,
6+
Union,
7+
Iterator,
8+
AsyncIterator,
9+
Protocol,
10+
runtime_checkable,
11+
)
312
import logging
413

514
import numpy as np
@@ -34,6 +43,18 @@ async def close(self):
3443
pass
3544

3645

46+
@runtime_checkable
47+
class OutputAudioTrack(Protocol):
48+
"""
49+
A protocol describing an output audio track; the actual implementation depends on the edge transport used
50+
e.g. getstream.video.rtc.audio_track.AudioStreamTrack
51+
"""
52+
53+
async def write(self, data: bytes) -> None: ...
54+
55+
def stop(self) -> None: ...
56+
57+
3758
class PcmData(NamedTuple):
3859
"""
3960
A named tuple representing PCM audio data.

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from vision_agents.core.edge import EdgeTransport, sfu_events
2424
from vision_agents.plugins.getstream.stream_conversation import StreamConversation
25-
from vision_agents.core.edge.types import Connection, User
25+
from vision_agents.core.edge.types import Connection, User, OutputAudioTrack
2626
from vision_agents.core.events.manager import EventManager
2727
from vision_agents.core.edge import events
2828
from vision_agents.core.utils import get_vision_agents_version
@@ -104,7 +104,7 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
104104
track_type_int = event.payload.type # TrackType enum int from SFU
105105
expected_kind = self._get_webrtc_kind(track_type_int)
106106
track_key = (user_id, session_id, track_type_int)
107-
is_agent_track = (user_id == self.agent_user_id)
107+
is_agent_track = user_id == self.agent_user_id
108108

109109
# First check if track already exists in map (e.g., from previous unpublish/republish)
110110
if track_key in self._track_map:
@@ -288,7 +288,9 @@ async def on_audio_received(pcm: PcmData, participant: Participant):
288288
standardize_connection = StreamConnection(connection)
289289
return standardize_connection
290290

291-
def create_audio_track(self, framerate: int = 48000, stereo: bool = True):
291+
def create_audio_track(
292+
self, framerate: int = 48000, stereo: bool = True
293+
) -> OutputAudioTrack:
292294
return audio_track.AudioStreamTrack(
293295
framerate=framerate, stereo=stereo
294296
) # default to webrtc framerate

0 commit comments

Comments
 (0)