@@ -40,7 +40,6 @@
utils,
)
from livekit.agents.metrics import RealtimeModelMetrics
from livekit.agents.metrics.base import Metadata
from livekit.agents.types import NOT_GIVEN, NotGivenOr
from livekit.agents.utils import is_given
from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker
@@ -168,6 +167,7 @@

def __init__(self) -> None:
self.session = boto3.Session() # type: ignore[attr-defined]
self._cached_credentials: AWSCredentialsIdentity | None = None

async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
"""Asynchronously resolve AWS credentials.
@@ -183,6 +183,10 @@
Raises:
ValueError: If no credentials could be found by boto3.
"""
# Return cached credentials if available
if self._cached_credentials is not None:
return self._cached_credentials

[GitHub Actions / ruff] Check failure on line 189 in livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/experimental/realtime/realtime_model.py: Ruff (W293) Blank line contains whitespace (189:1)
try:
logger.debug("Attempting to load AWS credentials")
credentials = self.session.get_credentials()
@@ -195,13 +199,14 @@
f"AWS credentials loaded successfully. AWS_ACCESS_KEY_ID: {creds.access_key[:4]}***"
)

identity = AWSCredentialsIdentity(
# Cache the credentials for future use
self._cached_credentials = AWSCredentialsIdentity(
access_key_id=creds.access_key,
secret_access_key=creds.secret_key,
session_token=creds.token if creds.token else None,
expiration=None,
)
return identity
return self._cached_credentials
except Exception as e:
logger.error(f"Failed to load AWS credentials: {str(e)}")
raise ValueError(f"Failed to load AWS credentials: {str(e)}") # noqa: B904
@@ -241,7 +246,6 @@
user_transcription=True,
auto_tool_reply_generation=True,
audio_output=True,
manual_function_calls=False,
)
)
self.model_id = "amazon.nova-sonic-v1:0"
@@ -260,14 +264,6 @@
)
self._sessions = weakref.WeakSet[RealtimeSession]()

@property
def model(self) -> str:
return self.model_id

@property
def provider(self) -> str:
return "Amazon"

def session(self) -> RealtimeSession:
"""Return a new RealtimeSession bound to this model instance."""
sess = RealtimeSession(self)
@@ -364,8 +360,8 @@
endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
region=self._realtime_model._opts.region,
aws_credentials_identity_resolver=Boto3CredentialsResolver(),
http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
auth_scheme_resolver=HTTPAuthSchemeResolver(),
auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")},
user_agent_extra="x-client-framework:livekit-plugins-aws[realtime]",
)
self._bedrock_client = BedrockRuntimeClient(config=config)
@@ -566,7 +562,6 @@
message_stream=self._current_generation.message_ch,
function_stream=self._current_generation.function_ch,
user_initiated=False,
response_id=self._current_generation.response_id,
)
self.emit("generation_created", generation_ev)

@@ -603,16 +598,11 @@
text_ch=utils.aio.Chan(),
audio_ch=utils.aio.Chan(),
)
msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
msg_modalities.set_result(
["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
)
self._current_generation.message_ch.send_nowait(
llm.MessageGeneration(
message_id=msg_gen.message_id,
text_stream=msg_gen.text_ch,
audio_stream=msg_gen.audio_ch,
modalities=msg_modalities,
)
)
self._current_generation.messages[self._current_generation.response_id] = msg_gen
@@ -777,16 +767,11 @@
audio_ch=utils.aio.Chan(),
)
self._current_generation.messages[self._current_generation.response_id] = msg_gen
msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
msg_modalities.set_result(
["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
)
self._current_generation.message_ch.send_nowait(
llm.MessageGeneration(
message_id=msg_gen.message_id,
text_stream=msg_gen.text_ch,
audio_stream=msg_gen.audio_ch,
modalities=msg_modalities,
)
)
self.emit_generation_event()
@@ -875,7 +860,7 @@
output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
# Q: should we be counting per turn or utterance?
metrics = RealtimeModelMetrics(
label=self._realtime_model.label,
label=self._realtime_model._label,
# TODO: pass in the correct request_id
request_id=event_data["event"]["usageEvent"]["completionId"],
timestamp=time.monotonic(),
@@ -902,9 +887,6 @@
audio_tokens=output_tokens["speechTokens"],
image_tokens=0,
),
metadata=Metadata(
model_name=self._realtime_model.model, model_provider=self._realtime_model.provider
),
)
self.emit("metrics_collected", metrics)

@@ -1259,12 +1241,7 @@
logger.warning("interrupt is not supported by Nova Sonic's Realtime API")

def truncate(
self,
*,
message_id: str,
modalities: list[Literal["text", "audio"]],
audio_end_ms: int,
audio_transcript: NotGivenOr[str] = NOT_GIVEN,
self, *, message_id: str, audio_end_ms: int, audio_transcript: NotGivenOr[str] = NOT_GIVEN
) -> None:
logger.warning("truncate is not supported by Nova Sonic's Realtime API")
