Skip to content

Commit 5bcffa3

Browse files
Merge pull request #119 from GetStream/fix-screensharing
Fix screensharing
2 parents 406673c + bfe888f commit 5bcffa3

File tree

6 files changed

+387
-142
lines changed

6 files changed

+387
-142
lines changed

agents-core/vision_agents/core/agents/agents.py

Lines changed: 223 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
1515
from ..edge import sfu_events
16-
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent
17-
from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack
16+
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, TrackRemovedEvent, CallEndedEvent
17+
from ..edge.types import Connection, Participant, PcmData, User
1818
from ..events.manager import EventManager
1919
from ..llm import events as llm_events
2020
from ..llm.events import (
@@ -33,6 +33,8 @@
3333
from ..tts.tts import TTS
3434
from ..tts.events import TTSAudioEvent
3535
from ..turn_detection import TurnDetector, TurnStartedEvent, TurnEndedEvent
36+
from ..utils.video_forwarder import VideoForwarder
37+
from ..utils.video_utils import ensure_even_dimensions
3638
from ..vad import VAD
3739
from ..vad.events import VADAudioEvent
3840
from . import events
@@ -159,6 +161,10 @@ def __init__(
159161
self._interval_task = None
160162
self._callback_executed = False
161163
self._track_tasks: Dict[str, asyncio.Task] = {}
164+
# Track metadata: track_id -> (track_type, participant, forwarder)
165+
self._active_video_tracks: Dict[str, tuple[int, Any, Any]] = {}
166+
self._video_forwarders: List[VideoForwarder] = []
167+
self._current_video_track_id: Optional[str] = None
162168
self._connection: Optional[Connection] = None
163169
self._audio_track: Optional[OutputAudioTrack] = None
164170
self._video_track: Optional[VideoStreamTrack] = None
@@ -669,10 +675,48 @@ async def on_track(event: TrackAddedEvent):
669675
if not track_id or not track_type:
670676
return
671677

678+
# If track is already being processed, just switch to it
679+
if track_id in self._active_video_tracks:
680+
track_type_name = TrackType.Name(track_type)
681+
self.logger.info(f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it")
682+
683+
if self.realtime_mode and isinstance(self.llm, Realtime):
684+
# Get the existing forwarder and switch to this track
685+
_, _, forwarder = self._active_video_tracks[track_id]
686+
track = self.edge.add_track_subscriber(track_id)
687+
if track and forwarder:
688+
await self.llm._watch_video_track(track, shared_forwarder=forwarder)
689+
self._current_video_track_id = track_id
690+
return
691+
672692
task = asyncio.create_task(self._process_track(track_id, track_type, user))
673693
self._track_tasks[track_id] = task
674694
task.add_done_callback(_log_task_exception)
675695

696+
@self.edge.events.subscribe
697+
async def on_track_removed(event: TrackRemovedEvent):
698+
track_id = event.track_id
699+
track_type = event.track_type
700+
if not track_id:
701+
return
702+
703+
track_type_name = TrackType.Name(track_type) if track_type else "unknown"
704+
self.logger.info(f"🎥 Track removed: {track_type_name} ({track_id})")
705+
706+
# Cancel the processing task for this track
707+
if track_id in self._track_tasks:
708+
self._track_tasks[track_id].cancel()
709+
self._track_tasks.pop(track_id)
710+
711+
# Clean up track metadata
712+
self._active_video_tracks.pop(track_id, None)
713+
714+
# If this was the active track, switch to any other available track
715+
if track_id == self._current_video_track_id and self.realtime_mode and isinstance(self.llm, Realtime):
716+
self.logger.info("🎥 Active video track removed, switching to next available")
717+
self._current_video_track_id = None
718+
await self._switch_to_next_available_track()
719+
676720
async def _reply_to_audio(
677721
self, pcm_data: PcmData, participant: Participant
678722
) -> None:
@@ -701,125 +745,193 @@ async def _reply_to_audio(
701745
self.logger.debug(f"🎵 Processing audio from {participant}")
702746
await self.stt.process_audio(pcm_data, participant)
703747

704-
async def _process_track(self, track_id: str, track_type: int, participant):
705-
# TODO: handle CancelledError
706-
# we only process video tracks
707-
if track_type != TrackType.TRACK_TYPE_VIDEO:
708-
return
709-
710-
# subscribe to the video track
711-
track = self.edge.add_track_subscriber(track_id)
712-
if not track:
713-
self.logger.error(f"Failed to subscribe to {track_id}")
748+
async def _switch_to_next_available_track(self) -> None:
749+
"""Switch to any available video track."""
750+
if not self._active_video_tracks:
751+
self.logger.info("🎥 No video tracks available")
752+
self._current_video_track_id = None
714753
return
754+
755+
# Just pick the first available video track
756+
for track_id, (track_type, participant, forwarder) in self._active_video_tracks.items():
757+
# Only consider video tracks (camera or screenshare)
758+
if track_type not in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE):
759+
continue
760+
761+
track_type_name = TrackType.Name(track_type)
762+
self.logger.info(f"🎥 Switching to track: {track_type_name} ({track_id})")
763+
764+
# Get the track and forwarder
765+
track = self.edge.add_track_subscriber(track_id)
766+
if track and forwarder and isinstance(self.llm, Realtime):
767+
# Send to Realtime provider
768+
await self.llm._watch_video_track(track, shared_forwarder=forwarder)
769+
self._current_video_track_id = track_id
770+
return
771+
else:
772+
self.logger.error(f"Failed to switch to track {track_id}")
773+
774+
self.logger.warning("🎥 No suitable video tracks found")
715775

716-
# Import VideoForwarder
717-
from ..utils.video_forwarder import VideoForwarder
718-
719-
# Create a SHARED VideoForwarder for the RAW incoming track
720-
# This prevents multiple recv() calls competing on the same track
721-
raw_forwarder = VideoForwarder(
722-
track, # type: ignore[arg-type]
723-
max_buffer=30,
724-
fps=30, # Max FPS for the producer (individual consumers can throttle down)
725-
name=f"raw_video_forwarder_{track_id}",
726-
)
727-
await raw_forwarder.start()
728-
self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id)
729-
730-
# Track forwarders for cleanup
731-
if not hasattr(self, "_video_forwarders"):
732-
self._video_forwarders = []
733-
self._video_forwarders.append(raw_forwarder)
776+
async def _process_track(self, track_id: str, track_type: int, participant):
777+
raw_forwarder = None
778+
processed_forwarder = None
779+
780+
try:
781+
# we only process video tracks (camera video or screenshare)
782+
if track_type not in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE):
783+
return
734784

735-
# If Realtime provider supports video, determine which track to send
736-
if self.realtime_mode:
737-
if self._video_track:
738-
# We have a video publisher (e.g., YOLO processor)
739-
# Create a separate forwarder for the PROCESSED video track
740-
self.logger.info(
741-
"🎥 Forwarding PROCESSED video frames to Realtime provider"
742-
)
743-
processed_forwarder = VideoForwarder(
744-
self._video_track, # type: ignore[arg-type]
745-
max_buffer=30,
746-
fps=30,
747-
name=f"processed_video_forwarder_{track_id}",
748-
)
749-
await processed_forwarder.start()
750-
self._video_forwarders.append(processed_forwarder)
785+
# subscribe to the video track
786+
track = self.edge.add_track_subscriber(track_id)
787+
if not track:
788+
self.logger.error(f"Failed to subscribe to {track_id}")
789+
return
751790

752-
if isinstance(self.llm, Realtime):
753-
# Send PROCESSED frames with the processed forwarder
754-
await self.llm._watch_video_track(
755-
self._video_track, shared_forwarder=processed_forwarder
791+
# Wrap screenshare tracks to ensure even dimensions for H.264 encoding
792+
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE:
793+
class _EvenDimensionsTrack(VideoStreamTrack):
794+
def __init__(self, src):
795+
super().__init__()
796+
self.src = src
797+
async def recv(self):
798+
return ensure_even_dimensions(await self.src.recv())
799+
800+
track = _EvenDimensionsTrack(track) # type: ignore[arg-type]
801+
802+
# Create a SHARED VideoForwarder for the RAW incoming track
803+
# This prevents multiple recv() calls competing on the same track
804+
raw_forwarder = VideoForwarder(
805+
track, # type: ignore[arg-type]
806+
max_buffer=30,
807+
fps=30, # Max FPS for the producer (individual consumers can throttle down)
808+
name=f"raw_video_forwarder_{track_id}",
809+
)
810+
await raw_forwarder.start()
811+
self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id)
812+
813+
# Track forwarders for cleanup
814+
self._video_forwarders.append(raw_forwarder)
815+
816+
# Store track metadata
817+
self._active_video_tracks[track_id] = (track_type, participant, raw_forwarder)
818+
819+
# If Realtime provider supports video, switch to this new track
820+
track_type_name = TrackType.Name(track_type)
821+
822+
if self.realtime_mode:
823+
if self._video_track:
824+
# We have a video publisher (e.g., YOLO processor)
825+
# Create a separate forwarder for the PROCESSED video track
826+
self.logger.info(
827+
"🎥 Forwarding PROCESSED video frames to Realtime provider"
756828
)
757-
else:
758-
# No video publisher, send raw frames
759-
self.logger.info("🎥 Forwarding RAW video frames to Realtime provider")
760-
if isinstance(self.llm, Realtime):
761-
await self.llm._watch_video_track(
762-
track, shared_forwarder=raw_forwarder
829+
processed_forwarder = VideoForwarder(
830+
self._video_track, # type: ignore[arg-type]
831+
max_buffer=30,
832+
fps=30,
833+
name=f"processed_video_forwarder_{track_id}",
834+
)
835+
await processed_forwarder.start()
836+
self._video_forwarders.append(processed_forwarder)
837+
838+
if isinstance(self.llm, Realtime):
839+
# Send PROCESSED frames with the processed forwarder
840+
await self.llm._watch_video_track(
841+
self._video_track, shared_forwarder=processed_forwarder
842+
)
843+
self._current_video_track_id = track_id
844+
else:
845+
# No video publisher, send raw frames - switch to this new track
846+
self.logger.info(f"🎥 Switching to {track_type_name} track: {track_id}")
847+
if isinstance(self.llm, Realtime):
848+
await self.llm._watch_video_track(
849+
track, shared_forwarder=raw_forwarder
850+
)
851+
self._current_video_track_id = track_id
852+
853+
has_image_processors = len(self.image_processors) > 0
854+
855+
# video processors - pass the raw forwarder (they process incoming frames)
856+
for processor in self.video_processors:
857+
try:
858+
await processor.process_video(
859+
track, participant.user_id, shared_forwarder=raw_forwarder
860+
)
861+
except Exception as e:
862+
self.logger.error(
863+
f"Error in video processor {type(processor).__name__}: {e}"
763864
)
764865

765-
hasImageProcessers = len(self.image_processors) > 0
766-
767-
# video processors - pass the raw forwarder (they process incoming frames)
768-
for processor in self.video_processors:
769-
try:
770-
await processor.process_video(
771-
track, participant.user_id, shared_forwarder=raw_forwarder
772-
)
773-
except Exception as e:
774-
self.logger.error(
775-
f"Error in video processor {type(processor).__name__}: {e}"
866+
# Use raw forwarder for image processors - only if there are image processors
867+
if not has_image_processors:
868+
# No image processors, just keep the connection alive
869+
self.logger.info(
870+
"No image processors, video processing handled by video processors only"
776871
)
872+
return
777873

778-
# Use raw forwarder for image processors - only if there are image processors
779-
if not hasImageProcessers:
780-
# No image processors, just keep the connection alive
781-
self.logger.info(
782-
"No image processors, video processing handled by video processors only"
783-
)
784-
return
785-
786-
# Initialize error tracking counters
787-
timeout_errors = 0
788-
consecutive_errors = 0
789-
790-
while True:
791-
try:
792-
# Use the raw forwarder instead of competing for track.recv()
793-
video_frame = await raw_forwarder.next_frame(timeout=2.0)
794-
795-
if video_frame:
796-
# Reset error counts on successful frame processing
797-
timeout_errors = 0
798-
consecutive_errors = 0
799-
800-
if hasImageProcessers:
801-
img = video_frame.to_image()
802-
803-
for processor in self.image_processors:
804-
try:
805-
await processor.process_image(img, participant.user_id)
806-
except Exception as e:
807-
self.logger.error(
808-
f"Error in image processor {type(processor).__name__}: {e}"
809-
)
810-
811-
else:
812-
self.logger.warning("🎥VDP: Received empty frame")
813-
consecutive_errors += 1
874+
# Initialize error tracking counters
875+
timeout_errors = 0
876+
consecutive_errors = 0
814877

815-
except asyncio.TimeoutError:
816-
# Exponential backoff for timeout errors
817-
timeout_errors += 1
818-
backoff_delay = min(2.0 ** min(timeout_errors, 5), 30.0)
819-
self.logger.debug(
820-
f"🎥VDP: Applying backoff delay: {backoff_delay:.1f}s"
821-
)
822-
await asyncio.sleep(backoff_delay)
878+
while True:
879+
try:
880+
# Use the raw forwarder instead of competing for track.recv()
881+
video_frame = await raw_forwarder.next_frame(timeout=2.0)
882+
883+
if video_frame:
884+
# Reset error counts on successful frame processing
885+
timeout_errors = 0
886+
consecutive_errors = 0
887+
888+
if has_image_processors:
889+
img = video_frame.to_image()
890+
891+
for processor in self.image_processors:
892+
try:
893+
await processor.process_image(img, participant.user_id)
894+
except Exception as e:
895+
self.logger.error(
896+
f"Error in image processor {type(processor).__name__}: {e}"
897+
)
898+
899+
else:
900+
self.logger.warning("🎥VDP: Received empty frame")
901+
consecutive_errors += 1
902+
903+
except asyncio.TimeoutError:
904+
# Exponential backoff for timeout errors
905+
timeout_errors += 1
906+
backoff_delay = min(2.0 ** min(timeout_errors, 5), 30.0)
907+
self.logger.debug(
908+
f"🎥VDP: Applying backoff delay: {backoff_delay:.1f}s"
909+
)
910+
await asyncio.sleep(backoff_delay)
911+
except asyncio.CancelledError:
912+
# Task was cancelled (e.g., track removed)
913+
# Clean up forwarders that were created for this track
914+
self.logger.debug(f"🎥 Cleaning up forwarders for cancelled track {track_id}")
915+
916+
# Stop and remove the raw forwarder if it was created
917+
if raw_forwarder is not None and hasattr(self, '_video_forwarders'):
918+
if raw_forwarder in self._video_forwarders:
919+
try:
920+
await raw_forwarder.stop()
921+
self._video_forwarders.remove(raw_forwarder)
922+
except Exception as e:
923+
self.logger.error(f"Error stopping raw forwarder: {e}")
924+
925+
# Stop and remove processed forwarder if it was created
926+
if processed_forwarder is not None and hasattr(self, '_video_forwarders'):
927+
if processed_forwarder in self._video_forwarders:
928+
try:
929+
await processed_forwarder.stop()
930+
self._video_forwarders.remove(processed_forwarder)
931+
except Exception as e:
932+
self.logger.error(f"Error stopping processed forwarder: {e}")
933+
934+
return
823935

824936
async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
825937
"""Handle turn detection events."""

0 commit comments

Comments
 (0)