Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agents-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"pillow>=11.3.0",
"numpy>=1.24.0",
"mcp>=1.16.0",
"torchvision>=0.23.0",
"colorlog>=6.10.1",
]

[project.urls]
Expand Down
140 changes: 103 additions & 37 deletions agents-core/vision_agents/core/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType

from ..edge import sfu_events
from ..edge.events import AudioReceivedEvent, TrackAddedEvent, TrackRemovedEvent, CallEndedEvent
from ..edge.events import (
AudioReceivedEvent,
TrackAddedEvent,
TrackRemovedEvent,
CallEndedEvent,
)
from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack
from ..events.manager import EventManager
from ..llm import events as llm_events
Expand All @@ -27,14 +32,19 @@
)
from ..llm.llm import LLM
from ..llm.realtime import Realtime
from ..logging_utils import CallContextToken, clear_call_context, set_call_context
from ..mcp import MCPBaseServer, MCPManager
from ..processors.base_processor import Processor, ProcessorType, filter_processors
from ..stt.events import STTTranscriptEvent, STTErrorEvent
from ..stt.stt import STT
from ..tts.tts import TTS
from ..tts.events import TTSAudioEvent
from ..turn_detection import TurnDetector, TurnStartedEvent, TurnEndedEvent
from ..utils.logging import (
CallContextToken,
clear_call_context,
set_call_context,
configure_default_logging,
)
from ..utils.video_forwarder import VideoForwarder
from ..utils.video_utils import ensure_even_dimensions
from ..vad import VAD
Expand Down Expand Up @@ -65,6 +75,16 @@ def _log_task_exception(task: asyncio.Task):
logger.exception("Error in background task")


class _AgentLoggerAdapter(logging.LoggerAdapter):
    """
    Logger adapter that prefixes every log message with the agent id.

    When ``extra`` contains an ``agent_id`` entry, messages are rewritten as
    ``[Agent: <agent_id>] | <message>`` and ``kwargs`` is returned untouched.
    In every other case (no ``extra``, or ``extra`` without ``agent_id``)
    processing is delegated to ``logging.LoggerAdapter.process``.
    """

    def process(self, msg: str, kwargs):
        """
        Return the ``(message, kwargs)`` pair handed to the wrapped logger.

        Args:
            msg: The log message (or format string) passed to the log call.
            kwargs: Keyword arguments of the log call.

        Returns:
            A tuple of the possibly prefixed message and the keyword arguments.
        """
        extra = self.extra
        # Check for the key explicitly: a missing "agent_id" must not make a
        # plain log statement raise KeyError in the caller.
        if extra and "agent_id" in extra:
            return "[Agent: %s] | %s" % (extra["agent_id"], msg), kwargs
        return super().process(msg, kwargs)

# TODO: move me
@dataclass
class AgentOptions:
Expand Down Expand Up @@ -132,22 +152,29 @@ def __init__(
# MCP servers for external tool and resource access
mcp_servers: Optional[List[MCPBaseServer]] = None,
options: Optional[AgentOptions] = None,
tracer: Tracer = trace.get_tracer("agents"),
# Configure the default logging for the sdk here. Pass None to leave the config intact.
log_level: Optional[int] = logging.INFO,
):
if log_level is not None:
configure_default_logging(level=log_level)
if options is None:
options = default_agent_options()
else:
options = default_agent_options().update(options)
self.options = options

self.instructions = instructions
self.edge = edge
self.agent_user = agent_user
self._agent_user_initialized = False

# only needed in case we spin threads
self.tracer = tracer
self._root_span: Optional[Span] = None
self._root_ctx: Optional[Context] = None

self.logger = logging.getLogger(f"Agent[{self.agent_user.id}]")
self.logger = _AgentLoggerAdapter(logger, {"agent_id": self.agent_user.id})

self.events = EventManager()
self.events.register_events_from_module(getstream.models, "call.")
Expand Down Expand Up @@ -183,7 +210,7 @@ def __init__(
# Merge plugin events BEFORE subscribing to any events
for plugin in [stt, tts, turn_detection, vad, llm, edge]:
if plugin and hasattr(plugin, "events"):
self.logger.info(f"Registered plugin {plugin}")
self.logger.debug(f"Register events from plugin {plugin}")
self.events.merge(plugin.events)

self.llm._attach_agent(self)
Expand Down Expand Up @@ -242,8 +269,8 @@ async def simple_response(
"""
Override simple_response if you want to change how the Agent class calls the LLM
"""
logger.info("asking LLM to reply to %s", text)
with self.span("simple_response") as span:
self.logger.info('🤖 Asking LLM to reply to "%s"', text)
with self.tracer.start_as_current_span("simple_response") as span:
response = await self.llm.simple_response(
text=text, processors=self.processors, participant=participant
)
Expand Down Expand Up @@ -504,6 +531,13 @@ async def finish(self):
Subscribes to the edge transport's `call_ended` event and awaits it. If
no connection is active, returns immediately.
"""
# If connection is None or already closed, return immediately
if not self._connection:
self.logger.info(
"🔚 Agent connection is already closed, finishing immediately"
)
return


with self.span("agent.finish"):
# If connection is None or already closed, return immediately
Expand Down Expand Up @@ -631,7 +665,7 @@ async def create_user(self) -> None:
return None

def _on_vad_audio(self, event: VADAudioEvent):
self.logger.info(f"Vad audio event {self._truncate_for_logging(event)}")
self.logger.debug(f"Vad audio event {self._truncate_for_logging(event)}")

def _on_rtc_reconnect(self):
# update the code to listen?
Expand Down Expand Up @@ -677,7 +711,7 @@ async def _on_agent_say(self, event: events.AgentSayEvent):
)
)

self.logger.info(f"Agent said: {event.text}")
self.logger.info(f"🤖 Agent said: {event.text}")
else:
self.logger.warning("No TTS available, cannot synthesize speech")

Expand Down Expand Up @@ -764,14 +798,18 @@ async def on_track(event: TrackAddedEvent):
# If track is already being processed, just switch to it
if track_id in self._active_video_tracks:
track_type_name = TrackType.Name(track_type)
self.logger.info(f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it")

self.logger.info(
f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it"
)

if self.realtime_mode and isinstance(self.llm, Realtime):
# Get the existing forwarder and switch to this track
_, _, forwarder = self._active_video_tracks[track_id]
track = self.edge.add_track_subscriber(track_id)
if track and forwarder:
await self.llm._watch_video_track(track, shared_forwarder=forwarder)
await self.llm._watch_video_track(
track, shared_forwarder=forwarder
)
self._current_video_track_id = track_id
return

Expand All @@ -796,10 +834,16 @@ async def on_track_removed(event: TrackRemovedEvent):

# Clean up track metadata
self._active_video_tracks.pop(track_id, None)

# If this was the active track, switch to any other available track
if track_id == self._current_video_track_id and self.realtime_mode and isinstance(self.llm, Realtime):
self.logger.info("🎥 Active video track removed, switching to next available")
if (
track_id == self._current_video_track_id
and self.realtime_mode
and isinstance(self.llm, Realtime)
):
self.logger.info(
"🎥 Active video track removed, switching to next available"
)
self._current_video_track_id = None
await self._switch_to_next_available_track()

Expand Down Expand Up @@ -839,16 +883,23 @@ async def _switch_to_next_available_track(self) -> None:
self.logger.info("🎥 No video tracks available")
self._current_video_track_id = None
return

# Just pick the first available video track
for track_id, (track_type, participant, forwarder) in self._active_video_tracks.items():
for track_id, (
track_type,
participant,
forwarder,
) in self._active_video_tracks.items():
# Only consider video tracks (camera or screenshare)
if track_type not in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE):
if track_type not in (
TrackType.TRACK_TYPE_VIDEO,
TrackType.TRACK_TYPE_SCREEN_SHARE,
):
continue

track_type_name = TrackType.Name(track_type)
self.logger.info(f"🎥 Switching to track: {track_type_name} ({track_id})")

# Get the track and forwarder
track = self.edge.add_track_subscriber(track_id)
if track and forwarder and isinstance(self.llm, Realtime):
Expand All @@ -858,16 +909,19 @@ async def _switch_to_next_available_track(self) -> None:
return
else:
self.logger.error(f"Failed to switch to track {track_id}")

self.logger.warning("🎥 No suitable video tracks found")

async def _process_track(self, track_id: str, track_type: int, participant):
raw_forwarder = None
processed_forwarder = None

try:
# we only process video tracks (camera video or screenshare)
if track_type not in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE):
if track_type not in (
TrackType.TRACK_TYPE_VIDEO,
TrackType.TRACK_TYPE_SCREEN_SHARE,
):
return

# subscribe to the video track
Expand All @@ -877,14 +931,16 @@ async def _process_track(self, track_id: str, track_type: int, participant):
return

# Wrap screenshare tracks to ensure even dimensions for H.264 encoding
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE:
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE:

class _EvenDimensionsTrack(VideoStreamTrack):
def __init__(self, src):
def __init__(self, src):
super().__init__()
self.src = src
async def recv(self):

async def recv(self):
return ensure_even_dimensions(await self.src.recv())

track = _EvenDimensionsTrack(track) # type: ignore[arg-type]

# Create a SHARED VideoForwarder for the RAW incoming track
Expand All @@ -896,17 +952,21 @@ async def recv(self):
name=f"raw_video_forwarder_{track_id}",
)
await raw_forwarder.start()
self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id)
self.logger.debug("🎥 Created raw VideoForwarder for track %s", track_id)

# Track forwarders for cleanup
self._video_forwarders.append(raw_forwarder)

# Store track metadata
self._active_video_tracks[track_id] = (track_type, participant, raw_forwarder)
self._active_video_tracks[track_id] = (
track_type,
participant,
raw_forwarder,
)

# If Realtime provider supports video, switch to this new track
track_type_name = TrackType.Name(track_type)

if self.realtime_mode:
if self._video_track:
# We have a video publisher (e.g., YOLO processor)
Expand All @@ -931,7 +991,9 @@ async def recv(self):
self._current_video_track_id = track_id
else:
# No video publisher, send raw frames - switch to this new track
self.logger.info(f"🎥 Switching to {track_type_name} track: {track_id}")
self.logger.info(
f"🎥 Switching to {track_type_name} track: {track_id}"
)
if isinstance(self.llm, Realtime):
await self.llm._watch_video_track(
track, shared_forwarder=raw_forwarder
Expand Down Expand Up @@ -978,7 +1040,9 @@ async def recv(self):

for processor in self.image_processors:
try:
await processor.process_image(img, participant.user_id)
await processor.process_image(
img, participant.user_id
)
except Exception as e:
self.logger.error(
f"Error in image processor {type(processor).__name__}: {e}"
Expand All @@ -999,26 +1063,28 @@ async def recv(self):
except asyncio.CancelledError:
# Task was cancelled (e.g., track removed)
# Clean up forwarders that were created for this track
self.logger.debug(f"🎥 Cleaning up forwarders for cancelled track {track_id}")

self.logger.debug(
f"🎥 Cleaning up forwarders for cancelled track {track_id}"
)

# Stop and remove the raw forwarder if it was created
if raw_forwarder is not None and hasattr(self, '_video_forwarders'):
if raw_forwarder is not None and hasattr(self, "_video_forwarders"):
if raw_forwarder in self._video_forwarders:
try:
await raw_forwarder.stop()
self._video_forwarders.remove(raw_forwarder)
except Exception as e:
self.logger.error(f"Error stopping raw forwarder: {e}")

# Stop and remove processed forwarder if it was created
if processed_forwarder is not None and hasattr(self, '_video_forwarders'):
if processed_forwarder is not None and hasattr(self, "_video_forwarders"):
if processed_forwarder in self._video_forwarders:
try:
await processed_forwarder.stop()
self._video_forwarders.remove(processed_forwarder)
except Exception as e:
self.logger.error(f"Error stopping processed forwarder: {e}")

return

async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
Expand Down
10 changes: 5 additions & 5 deletions agents-core/vision_agents/core/events/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ async def handle_audio_events(event: VADSpeechStartEvent | VADSpeechEndEvent):
subscribed = True
self._handlers.setdefault(event_type, []).append(function)
module_name = getattr(function, "__module__", "unknown")
logger.info(
logger.debug(
f"Handler {function.__name__} from {module_name} registered for event {event_type}"
)
elif not self._ignore_unknown_events:
Expand All @@ -362,7 +362,7 @@ async def handle_audio_events(event: VADSpeechStartEvent | VADSpeechEndEvent):
)
else:
module_name = getattr(function, "__module__", "unknown")
logger.info(
logger.debug(
f"Event {sub_event} - {event_type} is not registered – skipping handler {function.__name__} from {module_name}."
)
return function
Expand Down Expand Up @@ -400,7 +400,7 @@ def _prepare_event(self, event):
else:
# No matching event class found
if self._ignore_unknown_events:
logger.info(f"Protobuf event not registered: {proto_type}")
logger.debug(f"Protobuf event not registered: {proto_type}")
return
else:
raise RuntimeError(f"Protobuf event not registered: {proto_type}")
Expand All @@ -410,7 +410,7 @@ def _prepare_event(self, event):
# logger.info(f"Received event {_truncate_event_for_logging(event)}")
return event
elif self._ignore_unknown_events:
logger.info(f"Event not registered {_truncate_event_for_logging(event)}")
logger.debug(f"Event not registered {_truncate_event_for_logging(event)}")
else:
raise RuntimeError(f"Event not registered {event}")

Expand Down Expand Up @@ -507,7 +507,7 @@ async def _process_events_loop(self):
await self._process_single_event(event)
except asyncio.CancelledError as exc:
cancelled_exc = exc
logger.info(
logger.debug(
f"Event processing task was cancelled, processing remaining events, {len(self._queue)}"
)
await self._process_single_event(event)
Expand Down
Loading
Loading