diff --git a/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/experimental/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/experimental/realtime/realtime_model.py index 9b83466159..52167b1443 100644 --- a/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/experimental/realtime/realtime_model.py +++ b/livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/experimental/realtime/realtime_model.py @@ -40,7 +40,6 @@ utils, ) from livekit.agents.metrics import RealtimeModelMetrics -from livekit.agents.metrics.base import Metadata from livekit.agents.types import NOT_GIVEN, NotGivenOr from livekit.agents.utils import is_given from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker @@ -168,6 +167,7 @@ class Boto3CredentialsResolver(IdentityResolver): # type: ignore[misc] def __init__(self) -> None: self.session = boto3.Session() # type: ignore[attr-defined] + self._cached_credentials: AWSCredentialsIdentity | None = None async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity: """Asynchronously resolve AWS credentials. @@ -183,6 +183,10 @@ async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity: Raises: ValueError: If no credentials could be found by boto3. """ + # Return cached credentials if available + if self._cached_credentials is not None: + return self._cached_credentials + try: logger.debug("Attempting to load AWS credentials") credentials = self.session.get_credentials() @@ -195,13 +199,14 @@ async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity: f"AWS credentials loaded successfully. AWS_ACCESS_KEY_ID: {creds.access_key[:4]}***" ) - identity = AWSCredentialsIdentity( + # Cache the credentials for future use + self._cached_credentials = AWSCredentialsIdentity( access_key_id=creds.access_key, secret_access_key=creds.secret_key, session_token=creds.token if creds.token else None, expiration=None, ) - return identity + return self._cached_credentials except Exception as e: logger.error(f"Failed to load AWS credentials: {str(e)}") raise ValueError(f"Failed to load AWS credentials: {str(e)}") # noqa: B904 @@ -241,7 +246,6 @@ def __init__( user_transcription=True, auto_tool_reply_generation=True, audio_output=True, - manual_function_calls=False, ) ) self.model_id = "amazon.nova-sonic-v1:0" @@ -260,14 +264,6 @@ def __init__( ) self._sessions = weakref.WeakSet[RealtimeSession]() - @property - def model(self) -> str: - return self.model_id - - @property - def provider(self) -> str: - return "Amazon" - def session(self) -> RealtimeSession: """Return a new RealtimeSession bound to this model instance.""" sess = RealtimeSession(self) @@ -364,8 +360,8 @@ def _initialize_client(self) -> None: endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com", region=self._realtime_model._opts.region, aws_credentials_identity_resolver=Boto3CredentialsResolver(), - http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, user_agent_extra="x-client-framework:livekit-plugins-aws[realtime]", ) self._bedrock_client = BedrockRuntimeClient(config=config) @@ -566,7 +562,6 @@ def emit_generation_event(self) -> None: message_stream=self._current_generation.message_ch, function_stream=self._current_generation.function_ch, user_initiated=False, - response_id=self._current_generation.response_id, ) self.emit("generation_created", generation_ev) @@ -603,16 +598,11 @@ def _create_response_generation(self) -> None: text_ch=utils.aio.Chan(), audio_ch=utils.aio.Chan(), ) - msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]() - msg_modalities.set_result( - ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"] - ) self._current_generation.message_ch.send_nowait( llm.MessageGeneration( message_id=msg_gen.message_id, text_stream=msg_gen.text_ch, audio_stream=msg_gen.audio_ch, - modalities=msg_modalities, ) ) self._current_generation.messages[self._current_generation.response_id] = msg_gen @@ -777,16 +767,11 @@ async def _handle_tool_output_content_event(self, event_data: dict) -> None: audio_ch=utils.aio.Chan(), ) self._current_generation.messages[self._current_generation.response_id] = msg_gen - msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]() - msg_modalities.set_result( - ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"] - ) self._current_generation.message_ch.send_nowait( llm.MessageGeneration( message_id=msg_gen.message_id, text_stream=msg_gen.text_ch, audio_stream=msg_gen.audio_ch, - modalities=msg_modalities, ) ) self.emit_generation_event() @@ -875,7 +860,7 @@ async def _handle_usage_event(self, event_data: dict) -> None: output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"] # Q: should we be counting per turn or utterance? metrics = RealtimeModelMetrics( - label=self._realtime_model.label, + label=self._realtime_model._label, # TODO: pass in the correct request_id request_id=event_data["event"]["usageEvent"]["completionId"], timestamp=time.monotonic(), @@ -902,9 +887,6 @@ async def _handle_usage_event(self, event_data: dict) -> None: audio_tokens=output_tokens["speechTokens"], image_tokens=0, ), - metadata=Metadata( - model_name=self._realtime_model.model, model_provider=self._realtime_model.provider - ), ) self.emit("metrics_collected", metrics) @@ -1259,12 +1241,7 @@ def interrupt(self) -> None: logger.warning("interrupt is not supported by Nova Sonic's Realtime API") def truncate( - self, - *, - message_id: str, - modalities: list[Literal["text", "audio"]], - audio_end_ms: int, - audio_transcript: NotGivenOr[str] = NOT_GIVEN, + self, *, message_id: str, audio_end_ms: int, audio_transcript: NotGivenOr[str] = NOT_GIVEN ) -> None: logger.warning("truncate is not supported by Nova Sonic's Realtime API")