Skip to content

Commit e6767f4

Browse files
authored
Merge branch 'main' into feature/baseten
2 parents 4508013 + ec457b9 commit e6767f4

File tree

5 files changed

+79
-81
lines changed

5 files changed

+79
-81
lines changed

agents-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ classifiers = [
2121

2222
requires-python = ">=3.10"
2323
dependencies = [
24-
"getstream[webrtc,telemetry]>=2.5.9",
24+
"getstream[webrtc,telemetry]>=2.5.11",
2525
"python-dotenv>=1.1.1",
2626
"pillow>=10.4.0", # Compatible with moondream SDK (<11.0.0)
2727
"numpy>=1.24.0",

agents-core/vision_agents/core/agents/agents.py

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from aiortc import VideoStreamTrack
1414
from getstream.video.rtc import Call
1515

16+
from getstream.video.rtc.participants import ParticipantsState
1617
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
1718
from .agent_options import AgentOptions, default_agent_options
1819

@@ -23,7 +24,6 @@
2324
TrackRemovedEvent,
2425
CallEndedEvent,
2526
)
26-
from ..edge.sfu_events import ParticipantJoinedEvent
2727
from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack
2828
from ..events.manager import EventManager
2929
from ..llm import events as llm_events
@@ -76,7 +76,7 @@ class TrackInfo:
7676
id: str
7777
type: int
7878
processor: str
79-
priority: int # higher goes first
79+
priority: int # higher goes first
8080
participant: Optional[Participant]
8181
track: aiortc.mediastreams.VideoStreamTrack
8282
forwarder: VideoForwarder
@@ -90,6 +90,7 @@ class TrackInfo:
9090
- cleanup events more
9191
"""
9292

93+
9394
class Agent:
9495
"""
9596
Agent class makes it easy to build your own video AI.
@@ -146,6 +147,7 @@ def __init__(
146147
log_level: Optional[int] = logging.INFO,
147148
profiler: Optional[Profiler] = None,
148149
):
150+
self.participants: Optional[ParticipantsState] = None
149151
self.call = None
150152
self._active_processed_track_id: Optional[str] = None
151153
self._active_source_track_id: Optional[str] = None
@@ -211,7 +213,7 @@ def __init__(
211213

212214
# Attach processors that need agent reference
213215
for processor in self.processors:
214-
if hasattr(processor, '_attach_agent'):
216+
if hasattr(processor, "_attach_agent"):
215217
processor._attach_agent(self)
216218

217219
self.events.subscribe(self._on_vad_audio)
@@ -267,22 +269,20 @@ async def on_video_track_added(event: TrackAddedEvent | TrackRemovedEvent):
267269
if event.track_id is None or event.track_type is None or event.user is None:
268270
return
269271
if isinstance(event, TrackRemovedEvent):
270-
asyncio.create_task(self._on_track_removed(event.track_id, event.track_type, event.user))
272+
asyncio.create_task(
273+
self._on_track_removed(event.track_id, event.track_type, event.user)
274+
)
271275
else:
272-
asyncio.create_task(self._on_track_added(event.track_id, event.track_type, event.user))
276+
asyncio.create_task(
277+
self._on_track_added(event.track_id, event.track_type, event.user)
278+
)
273279

274280
# audio event for the user talking to the AI
275281
@self.edge.events.subscribe
276282
async def on_audio_received(event: AudioReceivedEvent):
277283
if event.participant is not None:
278284
await self._reply_to_audio(event.pcm_data, event.participant)
279285

280-
@self.edge.events.subscribe
281-
async def on_participant_joined(event: ParticipantJoinedEvent):
282-
if event.participant is not None:
283-
self.logger.info(f"Participant {event.participant.user_id} joined")
284-
self.participants[event.participant.session_id] = event.participant
285-
286286
@self.events.subscribe
287287
async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
288288
if _is_audio_llm(self.llm):
@@ -342,7 +342,7 @@ async def on_stt_transcript_event_sync_conversation(event: STTTranscriptEvent):
342342

343343
@self.events.subscribe
344344
async def on_realtime_user_speech_transcription(
345-
event: RealtimeUserSpeechTranscriptionEvent,
345+
event: RealtimeUserSpeechTranscriptionEvent,
346346
):
347347
self.logger.info(f"🎤 [User transcript]: {event.text}")
348348

@@ -362,7 +362,7 @@ async def on_realtime_user_speech_transcription(
362362

363363
@self.events.subscribe
364364
async def on_realtime_agent_speech_transcription(
365-
event: RealtimeAgentSpeechTranscriptionEvent,
365+
event: RealtimeAgentSpeechTranscriptionEvent,
366366
):
367367
self.logger.info(f"🎤 [Agent transcript]: {event.text}")
368368

@@ -380,9 +380,6 @@ async def on_realtime_agent_speech_transcription(
380380
original=event,
381381
)
382382

383-
384-
385-
386383
@self.llm.events.subscribe
387384
async def on_llm_response_sync_conversation(event: LLMResponseCompletedEvent):
388385
self.logger.info(f"🤖 [LLM response]: {event.text} {event.item_id}")
@@ -424,7 +421,6 @@ async def _handle_output_text_delta(event: LLMResponseChunkEvent):
424421

425422
logger.info("AUDIO: SUB TO EVENTS DONE")
426423

427-
428424
async def simple_response(
429425
self, text: str, participant: Optional[Participant] = None
430426
) -> None:
@@ -441,7 +437,7 @@ async def simple_response(
441437
span.set_attribute("response.original", response.original)
442438

443439
async def simple_audio_response(
444-
self, pcm: PcmData, participant: Optional[Participant] = None
440+
self, pcm: PcmData, participant: Optional[Participant] = None
445441
) -> None:
446442
"""
447443
Makes it easy to subclass how the agent calls the LLM for processing audio
@@ -463,7 +459,9 @@ def subscribe(self, function):
463459
"""
464460
return self.events.subscribe(function)
465461

466-
async def join(self, call: Call, wait_for_participant=True) -> "AgentSessionContextManager":
462+
async def join(
463+
self, call: Call, wait_for_participant=True
464+
) -> "AgentSessionContextManager":
467465
# TODO: validation. join can only be called once
468466
self.logger.info("joining call")
469467
# run start on all subclasses
@@ -504,7 +502,8 @@ async def join(self, call: Call, wait_for_participant=True) -> "AgentSessionCont
504502

505503
with self.span("edge.join"):
506504
connection = await self.edge.join(self, call)
507-
self.participants = connection._connection.participants_state._participant_by_prefix
505+
self.participants = connection.participants
506+
508507
except Exception:
509508
self.clear_call_logging_context()
510509
raise
@@ -548,30 +547,23 @@ async def join(self, call: Call, wait_for_participant=True) -> "AgentSessionCont
548547

549548
async def wait_for_participant(self):
550549
"""wait for a participant other than the AI agent to join"""
551-
# Check if a non-agent participant is already present
552-
if self.call and self.participants:
553-
for p in self.participants.values():
554-
if p.user_id != self.agent_user.id:
555-
self.logger.info(f"Participant {p.user_id} already in call")
556-
return
557550

558-
# If not, wait for one to join
559-
participant_joined = asyncio.Event()
551+
if self.participants is None:
552+
return
560553

561-
@self.edge.events.subscribe
562-
async def on_participant_joined(event: ParticipantJoinedEvent):
563-
if event.participant is not None:
564-
is_agent = event.participant.user_id == self.agent_user.id
554+
participant_joined = asyncio.Event()
565555

566-
self.logger.info(f"Participant {event.participant.user_id} joined is_agent {is_agent}")
567-
if not is_agent:
556+
def on_participants(participants):
557+
for p in participants:
558+
if p.user_id != self.agent_user.id:
568559
participant_joined.set()
569560

570-
# Wait for the event to be set
571-
await participant_joined.wait()
561+
subscription = self.participants.map(on_participants)
572562

573-
# Clean up the subscription
574-
self.edge.events.unsubscribe(on_participant_joined)
563+
try:
564+
await participant_joined.wait()
565+
finally:
566+
subscription.unsubscribe()
575567

576568
async def finish(self):
577569
"""Wait for the call to end gracefully.
@@ -624,7 +616,10 @@ async def _apply(self, function_name: str, *args, **kwargs):
624616
subclasses = [self.llm, self.stt, self.tts, self.turn_detection, self.edge]
625617
subclasses.extend(self.processors)
626618
for subclass in subclasses:
627-
if subclass is not None and getattr(subclass, function_name, None) is not None:
619+
if (
620+
subclass is not None
621+
and getattr(subclass, function_name, None) is not None
622+
):
628623
func = getattr(subclass, function_name)
629624
if func is not None:
630625
await func(*args, **kwargs)
@@ -642,13 +637,6 @@ def _end_tracing(self):
642637
def __aexit__(self, exc_type, exc_val, exc_tb):
643638
self._end_tracing()
644639

645-
646-
647-
648-
649-
650-
651-
652640
async def close(self):
653641
"""Clean up all connections and resources.
654642
@@ -888,7 +876,9 @@ async def _reply_to_audio_consumer(self) -> None:
888876
for processor in self.audio_processors:
889877
if processor is None:
890878
continue
891-
await processor.process_audio(audio_bytes, participant.user_id)
879+
await processor.process_audio(
880+
audio_bytes, participant.user_id
881+
)
892882

893883
# when in Realtime mode call the Realtime directly (non-blocking)
894884
if _is_audio_llm(self.llm):
@@ -937,48 +927,60 @@ async def _image_to_video_processors(self, track_id: str, track_type: int):
937927
for processor in self.image_processors:
938928
try:
939929
pass
940-
#TODO: run this better
941-
#await processor.process_image(
930+
# TODO: run this better
931+
# await processor.process_image(
942932
# img, track_info.participant.user_id, track_id=track_id, track_type=track_type
943-
#)
933+
# )
944934
except Exception as e:
945935
self.logger.error(
946936
f"Error in image processor {type(processor).__name__}: {e}"
947937
)
948938

949-
async def _on_track_removed(self, track_id: str, track_type: int, participant: Participant):
939+
async def _on_track_removed(
940+
self, track_id: str, track_type: int, participant: Participant
941+
):
950942
self._active_video_tracks.pop(track_id)
951943
await self._on_track_change(track_id)
952944

953945
async def _on_track_change(self, track_id: str):
954946
# shared logic between track removal and track addition
955947
# Select a track. Prioritize screenshare over regular
956948
# This is the track without processing
957-
non_processed_tracks = [t for t in self._active_video_tracks.values() if not t.processor]
958-
source_track = sorted(non_processed_tracks, key=lambda t: t.priority, reverse=True)[0]
949+
non_processed_tracks = [
950+
t for t in self._active_video_tracks.values() if not t.processor
951+
]
952+
source_track = sorted(
953+
non_processed_tracks, key=lambda t: t.priority, reverse=True
954+
)[0]
959955
# assign the tracks that we last used so we can notify of changes...
960956
self._active_source_track_id = source_track.id
961957

962958
await self._track_to_video_processors(source_track)
963959

964-
processed_track = sorted([t for t in self._active_video_tracks.values()], key=lambda t: t.priority, reverse=True)[0]
960+
processed_track = sorted(
961+
[t for t in self._active_video_tracks.values()],
962+
key=lambda t: t.priority,
963+
reverse=True,
964+
)[0]
965965
self._active_processed_track_id = processed_track.id
966966

967967
# See if we have a processed track. If so forward that to LLM
968968
# TODO: this should run in a loop and handle multiple forwarders
969-
#self._image_to_video_processors()
969+
# self._image_to_video_processors()
970970

971971
# If Realtime provider supports video, switch to this new track
972972
if _is_video_llm(self.llm):
973973
await self.llm.watch_video_track(
974974
processed_track.track, shared_forwarder=processed_track.forwarder
975975
)
976976

977-
async def _on_track_added(self, track_id: str, track_type: int, participant: Participant):
977+
async def _on_track_added(
978+
self, track_id: str, track_type: int, participant: Participant
979+
):
978980
# We only process video tracks (camera video or screenshare)
979981
if track_type not in (
980-
TrackType.TRACK_TYPE_VIDEO,
981-
TrackType.TRACK_TYPE_SCREEN_SHARE,
982+
TrackType.TRACK_TYPE_VIDEO,
983+
TrackType.TRACK_TYPE_SCREEN_SHARE,
982984
):
983985
return
984986

@@ -1001,14 +1003,12 @@ async def _on_track_added(self, track_id: str, track_type: int, participant: Par
10011003
processor="",
10021004
track=track,
10031005
participant=participant,
1004-
priority = 1 if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE else 0,
1005-
forwarder = forwarder
1006+
priority=1 if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE else 0,
1007+
forwarder=forwarder,
10061008
)
10071009

10081010
await self._on_track_change(track_id)
10091011

1010-
1011-
10121012
async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
10131013
"""Handle turn detection events."""
10141014
# Skip the turn event handling if the model doesn't require TTS or STT audio itself.
@@ -1066,7 +1066,6 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
10661066
# Clear the pending transcript for this speaker
10671067
self._pending_user_transcripts[event.participant.user_id] = ""
10681068

1069-
10701069
@property
10711070
def publish_audio(self) -> bool:
10721071
"""Whether the agent should publish an outbound audio track.
@@ -1241,7 +1240,7 @@ def _prepare_rtc(self):
12411240
track=self._video_track,
12421241
participant=None,
12431242
priority=2,
1244-
forwarder=forwarder
1243+
forwarder=forwarder,
12451244
)
12461245

12471246
self.logger.info("🎥 Video track initialized from video publisher")

agents-core/vision_agents/core/profiling/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pyinstrument
21
import logging
32

43
from vision_agents.core.events import EventManager
@@ -29,6 +28,8 @@ def __init__(self, output_path='./profile.html'):
2928
output_path: Path where the HTML profile report will be saved.
3029
Defaults to './profile.html'.
3130
"""
31+
import pyinstrument
32+
3233
self.output_path = output_path
3334
self.events = EventManager()
3435
self.events.register_events_from_module(events)

0 commit comments

Comments
 (0)