1313from aiortc import VideoStreamTrack
1414from getstream .video .rtc import Call
1515
16+ from getstream .video .rtc .participants import ParticipantsState
1617from getstream .video .rtc .pb .stream .video .sfu .models .models_pb2 import TrackType
1718from .agent_options import AgentOptions , default_agent_options
1819
2324 TrackRemovedEvent ,
2425 CallEndedEvent ,
2526)
26- from ..edge .sfu_events import ParticipantJoinedEvent
2727from ..edge .types import Connection , Participant , PcmData , User , OutputAudioTrack
2828from ..events .manager import EventManager
2929from ..llm import events as llm_events
@@ -76,7 +76,7 @@ class TrackInfo:
7676 id : str
7777 type : int
7878 processor : str
79- priority : int # higher goes first
79+ priority : int # higher goes first
8080 participant : Optional [Participant ]
8181 track : aiortc .mediastreams .VideoStreamTrack
8282 forwarder : VideoForwarder
@@ -90,6 +90,7 @@ class TrackInfo:
9090- cleanup events more
9191"""
9292
93+
9394class Agent :
9495 """
9596 Agent class makes it easy to build your own video AI.
@@ -146,6 +147,7 @@ def __init__(
146147 log_level : Optional [int ] = logging .INFO ,
147148 profiler : Optional [Profiler ] = None ,
148149 ):
150+ self .participants : Optional [ParticipantsState ] = None
149151 self .call = None
150152 self ._active_processed_track_id : Optional [str ] = None
151153 self ._active_source_track_id : Optional [str ] = None
@@ -211,7 +213,7 @@ def __init__(
211213
212214 # Attach processors that need agent reference
213215 for processor in self .processors :
214- if hasattr (processor , ' _attach_agent' ):
216+ if hasattr (processor , " _attach_agent" ):
215217 processor ._attach_agent (self )
216218
217219 self .events .subscribe (self ._on_vad_audio )
@@ -267,22 +269,20 @@ async def on_video_track_added(event: TrackAddedEvent | TrackRemovedEvent):
267269 if event .track_id is None or event .track_type is None or event .user is None :
268270 return
269271 if isinstance (event , TrackRemovedEvent ):
270- asyncio .create_task (self ._on_track_removed (event .track_id , event .track_type , event .user ))
272+ asyncio .create_task (
273+ self ._on_track_removed (event .track_id , event .track_type , event .user )
274+ )
271275 else :
272- asyncio .create_task (self ._on_track_added (event .track_id , event .track_type , event .user ))
276+ asyncio .create_task (
277+ self ._on_track_added (event .track_id , event .track_type , event .user )
278+ )
273279
274280 # audio event for the user talking to the AI
275281 @self .edge .events .subscribe
276282 async def on_audio_received (event : AudioReceivedEvent ):
277283 if event .participant is not None :
278284 await self ._reply_to_audio (event .pcm_data , event .participant )
279285
280- @self .edge .events .subscribe
281- async def on_participant_joined (event : ParticipantJoinedEvent ):
282- if event .participant is not None :
283- self .logger .info (f"Participant { event .participant .user_id } joined" )
284- self .participants [event .participant .session_id ] = event .participant
285-
286286 @self .events .subscribe
287287 async def on_stt_transcript_event_create_response (event : STTTranscriptEvent ):
288288 if _is_audio_llm (self .llm ):
@@ -342,7 +342,7 @@ async def on_stt_transcript_event_sync_conversation(event: STTTranscriptEvent):
342342
343343 @self .events .subscribe
344344 async def on_realtime_user_speech_transcription (
345- event : RealtimeUserSpeechTranscriptionEvent ,
345+ event : RealtimeUserSpeechTranscriptionEvent ,
346346 ):
347347 self .logger .info (f"🎤 [User transcript]: { event .text } " )
348348
@@ -362,7 +362,7 @@ async def on_realtime_user_speech_transcription(
362362
363363 @self .events .subscribe
364364 async def on_realtime_agent_speech_transcription (
365- event : RealtimeAgentSpeechTranscriptionEvent ,
365+ event : RealtimeAgentSpeechTranscriptionEvent ,
366366 ):
367367 self .logger .info (f"🎤 [Agent transcript]: { event .text } " )
368368
@@ -380,9 +380,6 @@ async def on_realtime_agent_speech_transcription(
380380 original = event ,
381381 )
382382
383-
384-
385-
386383 @self .llm .events .subscribe
387384 async def on_llm_response_sync_conversation (event : LLMResponseCompletedEvent ):
388385 self .logger .info (f"🤖 [LLM response]: { event .text } { event .item_id } " )
@@ -424,7 +421,6 @@ async def _handle_output_text_delta(event: LLMResponseChunkEvent):
424421
425422 logger .info ("AUDIO: SUB TO EVENTS DONE" )
426423
427-
428424 async def simple_response (
429425 self , text : str , participant : Optional [Participant ] = None
430426 ) -> None :
@@ -441,7 +437,7 @@ async def simple_response(
441437 span .set_attribute ("response.original" , response .original )
442438
443439 async def simple_audio_response (
444- self , pcm : PcmData , participant : Optional [Participant ] = None
440+ self , pcm : PcmData , participant : Optional [Participant ] = None
445441 ) -> None :
446442 """
447443 Makes it easy to subclass how the agent calls the LLM for processing audio
@@ -463,7 +459,9 @@ def subscribe(self, function):
463459 """
464460 return self .events .subscribe (function )
465461
466- async def join (self , call : Call , wait_for_participant = True ) -> "AgentSessionContextManager" :
462+ async def join (
463+ self , call : Call , wait_for_participant = True
464+ ) -> "AgentSessionContextManager" :
467465 # TODO: validation. join can only be called once
468466 self .logger .info ("joining call" )
469467 # run start on all subclasses
@@ -504,7 +502,8 @@ async def join(self, call: Call, wait_for_participant=True) -> "AgentSessionCont
504502
505503 with self .span ("edge.join" ):
506504 connection = await self .edge .join (self , call )
507- self .participants = connection ._connection .participants_state ._participant_by_prefix
505+ self .participants = connection .participants
506+
508507 except Exception :
509508 self .clear_call_logging_context ()
510509 raise
@@ -548,30 +547,23 @@ async def join(self, call: Call, wait_for_participant=True) -> "AgentSessionCont
548547
549548 async def wait_for_participant (self ):
550549 """wait for a participant other than the AI agent to join"""
551- # Check if a non-agent participant is already present
552- if self .call and self .participants :
553- for p in self .participants .values ():
554- if p .user_id != self .agent_user .id :
555- self .logger .info (f"Participant { p .user_id } already in call" )
556- return
557550
558- # If not, wait for one to join
559- participant_joined = asyncio . Event ()
551+ if self . participants is None :
552+ return
560553
561- @self .edge .events .subscribe
562- async def on_participant_joined (event : ParticipantJoinedEvent ):
563- if event .participant is not None :
564- is_agent = event .participant .user_id == self .agent_user .id
554+ participant_joined = asyncio .Event ()
565555
566- self .logger .info (f"Participant { event .participant .user_id } joined is_agent { is_agent } " )
567- if not is_agent :
556+ def on_participants (participants ):
557+ for p in participants :
558+ if p .user_id != self .agent_user .id :
568559 participant_joined .set ()
569560
570- # Wait for the event to be set
571- await participant_joined .wait ()
561+ subscription = self .participants .map (on_participants )
572562
573- # Clean up the subscription
574- self .edge .events .unsubscribe (on_participant_joined )
563+ try :
564+ await participant_joined .wait ()
565+ finally :
566+ subscription .unsubscribe ()
575567
576568 async def finish (self ):
577569 """Wait for the call to end gracefully.
@@ -624,7 +616,10 @@ async def _apply(self, function_name: str, *args, **kwargs):
624616 subclasses = [self .llm , self .stt , self .tts , self .turn_detection , self .edge ]
625617 subclasses .extend (self .processors )
626618 for subclass in subclasses :
627- if subclass is not None and getattr (subclass , function_name , None ) is not None :
619+ if (
620+ subclass is not None
621+ and getattr (subclass , function_name , None ) is not None
622+ ):
628623 func = getattr (subclass , function_name )
629624 if func is not None :
630625 await func (* args , ** kwargs )
@@ -642,13 +637,6 @@ def _end_tracing(self):
642637 def __aexit__ (self , exc_type , exc_val , exc_tb ):
643638 self ._end_tracing ()
644639
645-
646-
647-
648-
649-
650-
651-
652640 async def close (self ):
653641 """Clean up all connections and resources.
654642
@@ -888,7 +876,9 @@ async def _reply_to_audio_consumer(self) -> None:
888876 for processor in self .audio_processors :
889877 if processor is None :
890878 continue
891- await processor .process_audio (audio_bytes , participant .user_id )
879+ await processor .process_audio (
880+ audio_bytes , participant .user_id
881+ )
892882
893883 # when in Realtime mode call the Realtime directly (non-blocking)
894884 if _is_audio_llm (self .llm ):
@@ -937,48 +927,60 @@ async def _image_to_video_processors(self, track_id: str, track_type: int):
937927 for processor in self .image_processors :
938928 try :
939929 pass
940- #TODO: run this better
941- #await processor.process_image(
930+ # TODO: run this better
931+ # await processor.process_image(
942932 # img, track_info.participant.user_id, track_id=track_id, track_type=track_type
943- #)
933+ # )
944934 except Exception as e :
945935 self .logger .error (
946936 f"Error in image processor { type (processor ).__name__ } : { e } "
947937 )
948938
949- async def _on_track_removed (self , track_id : str , track_type : int , participant : Participant ):
939+ async def _on_track_removed (
940+ self , track_id : str , track_type : int , participant : Participant
941+ ):
950942 self ._active_video_tracks .pop (track_id )
951943 await self ._on_track_change (track_id )
952944
953945 async def _on_track_change (self , track_id : str ):
954946 # shared logic between track remove and added
955947 # Select a track. Prioritize screenshare over regular
956948 # This is the track without processing
957- non_processed_tracks = [t for t in self ._active_video_tracks .values () if not t .processor ]
958- source_track = sorted (non_processed_tracks , key = lambda t : t .priority , reverse = True )[0 ]
949+ non_processed_tracks = [
950+ t for t in self ._active_video_tracks .values () if not t .processor
951+ ]
952+ source_track = sorted (
953+ non_processed_tracks , key = lambda t : t .priority , reverse = True
954+ )[0 ]
959955 # assign the tracks that we last used so we can notify of changes...
960956 self ._active_source_track_id = source_track .id
961957
962958 await self ._track_to_video_processors (source_track )
963959
964- processed_track = sorted ([t for t in self ._active_video_tracks .values ()], key = lambda t : t .priority , reverse = True )[0 ]
960+ processed_track = sorted (
961+ [t for t in self ._active_video_tracks .values ()],
962+ key = lambda t : t .priority ,
963+ reverse = True ,
964+ )[0 ]
965965 self ._active_processed_track_id = processed_track .id
966966
967967 # See if we have a processed track. If so forward that to LLM
968968 # TODO: this should run in a loop and handle multiple forwarders
969- #self._image_to_video_processors()
969+ # self._image_to_video_processors()
970970
971971 # If Realtime provider supports video, switch to this new track
972972 if _is_video_llm (self .llm ):
973973 await self .llm .watch_video_track (
974974 processed_track .track , shared_forwarder = processed_track .forwarder
975975 )
976976
977- async def _on_track_added (self , track_id : str , track_type : int , participant : Participant ):
977+ async def _on_track_added (
978+ self , track_id : str , track_type : int , participant : Participant
979+ ):
978980 # We only process video tracks (camera video or screenshare)
979981 if track_type not in (
980- TrackType .TRACK_TYPE_VIDEO ,
981- TrackType .TRACK_TYPE_SCREEN_SHARE ,
982+ TrackType .TRACK_TYPE_VIDEO ,
983+ TrackType .TRACK_TYPE_SCREEN_SHARE ,
982984 ):
983985 return
984986
@@ -1001,14 +1003,12 @@ async def _on_track_added(self, track_id: str, track_type: int, participant: Par
10011003 processor = "" ,
10021004 track = track ,
10031005 participant = participant ,
1004- priority = 1 if track_type == TrackType .TRACK_TYPE_SCREEN_SHARE else 0 ,
1005- forwarder = forwarder
1006+ priority = 1 if track_type == TrackType .TRACK_TYPE_SCREEN_SHARE else 0 ,
1007+ forwarder = forwarder ,
10061008 )
10071009
10081010 await self ._on_track_change (track_id )
10091011
1010-
1011-
10121012 async def _on_turn_event (self , event : TurnStartedEvent | TurnEndedEvent ) -> None :
10131013 """Handle turn detection events."""
10141014 # Skip the turn event handling if the model doesn't require TTS or SST audio itself.
@@ -1066,7 +1066,6 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
10661066 # Clear the pending transcript for this speaker
10671067 self ._pending_user_transcripts [event .participant .user_id ] = ""
10681068
1069-
10701069 @property
10711070 def publish_audio (self ) -> bool :
10721071 """Whether the agent should publish an outbound audio track.
@@ -1241,7 +1240,7 @@ def _prepare_rtc(self):
12411240 track = self ._video_track ,
12421241 participant = None ,
12431242 priority = 2 ,
1244- forwarder = forwarder
1243+ forwarder = forwarder ,
12451244 )
12461245
12471246 self .logger .info ("🎥 Video track initialized from video publisher" )
0 commit comments