4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/llm/_provider_format/aws.py
@@ -5,6 +5,7 @@
from dataclasses import dataclass

from livekit.agents import llm
+ from livekit.agents.log import logger

from .utils import group_tool_calls

@@ -33,6 +34,9 @@ def to_chat_ctx(
role = "assistant"
elif msg.type == "function_call_output":
role = "user"
+ else:
+ logger.warning("Skipping unknown message type %r", msg.type)
+ continue

# if the effective role changed, finalize the previous turn.
if role != current_role:
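Note on the `aws.py` change: chat items with an unrecognized `type` are now logged and skipped instead of falling through with a stale `role`. A minimal standalone sketch of the pattern (the `"message"` mapping and the `role_for` helper are hypothetical stand-ins; the real mapping lives inside `to_chat_ctx`):

```python
import logging

logger = logging.getLogger(__name__)

def role_for(msg_type: str) -> str | None:
    """Map a chat-item type to a turn role; None means skip the item."""
    if msg_type == "message":
        return "user"  # hypothetical: the real code derives this from the item itself
    elif msg_type == "function_call":
        return "assistant"
    elif msg_type == "function_call_output":
        return "user"
    logger.warning("Skipping unknown message type %r", msg_type)
    return None

# Unrecognized items are dropped rather than raising or corrupting the turn order.
roles = [r for r in (role_for(t) for t in ("message", "mystery")) if r is not None]
assert roles == ["user"]
```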
89 changes: 59 additions & 30 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/llm.py
@@ -16,10 +16,10 @@

import os
from dataclasses import dataclass
- from typing import Any, cast
+ from typing import TYPE_CHECKING, Any, cast

- import aioboto3  # type: ignore
- from botocore.config import Config  # type: ignore
+ import aioboto3
+ from aiobotocore.config import AioConfig

from livekit.agents import APIConnectionError, APIStatusError, llm
from livekit.agents.llm import (
@@ -40,6 +40,18 @@
from .log import logger
from .utils import to_fnc_ctx

+ if TYPE_CHECKING:
+ from types_aiobotocore_bedrock_runtime.type_defs import (
+ ConverseStreamOutputTypeDef,
+ ConverseStreamRequestTypeDef,
+ GuardrailStreamConfigurationTypeDef,
+ InferenceConfigurationTypeDef,
+ MessageUnionTypeDef,
+ SystemContentBlockTypeDef,
+ ToolConfigurationTypeDef,
+ )
+ else:
+ MessageUnionTypeDef = dict
DEFAULT_TEXT_MODEL = "anthropic.claude-3-5-sonnet-20240620-v1:0"


@@ -50,6 +62,7 @@ class _LLMOptions:
tool_choice: NotGivenOr[ToolChoice]
max_output_tokens: NotGivenOr[int]
top_p: NotGivenOr[float]
+ guardrail_config: NotGivenOr[GuardrailStreamConfigurationTypeDef]
additional_request_fields: NotGivenOr[dict[str, Any]]
cache_system: bool
cache_tools: bool
@@ -59,14 +72,15 @@ class LLM(llm.LLM):
def __init__(
self,
*,
- model: NotGivenOr[str] = DEFAULT_TEXT_MODEL,
+ model: str = DEFAULT_TEXT_MODEL,
api_key: NotGivenOr[str] = NOT_GIVEN,
api_secret: NotGivenOr[str] = NOT_GIVEN,
region: NotGivenOr[str] = "us-east-1",
temperature: NotGivenOr[float] = NOT_GIVEN,
max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
top_p: NotGivenOr[float] = NOT_GIVEN,
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
+ guardrail_config: NotGivenOr[GuardrailStreamConfigurationTypeDef] = NOT_GIVEN,
additional_request_fields: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
cache_system: bool = False,
cache_tools: bool = False,
@@ -116,6 +130,7 @@ def __init__(
tool_choice=tool_choice,
max_output_tokens=max_output_tokens,
top_p=top_p,
+ guardrail_config=guardrail_config,
additional_request_fields=additional_request_fields,
cache_system=cache_system,
cache_tools=cache_tools,
@@ -140,13 +155,9 @@ def chat(
temperature: NotGivenOr[float] = NOT_GIVEN,
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> LLMStream:
- opts: dict[str, Any] = {}
- extra_kwargs = extra_kwargs if is_given(extra_kwargs) else {}
+ opts: ConverseStreamRequestTypeDef = {"modelId": self._opts.model}

- if is_given(self._opts.model):
- opts["modelId"] = self._opts.model

- def _get_tool_config() -> dict[str, Any] | None:
+ def _get_tool_config() -> ToolConfigurationTypeDef | None:
nonlocal tool_choice

if not tools:
@@ -156,7 +167,7 @@ def _get_tool_config() -> dict[str, Any] | None:
if self._opts.cache_tools:
tools_list.append({"cachePoint": {"type": "default"}})

- tool_config: dict[str, Any] = {"tools": tools_list}
+ tool_config: ToolConfigurationTypeDef = {"tools": tools_list}
tool_choice = (
cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
)
@@ -175,17 +186,31 @@
tool_config = _get_tool_config()
if tool_config:
opts["toolConfig"] = tool_config

messages, extra_data = chat_ctx.to_provider_format(format="aws")
opts["messages"] = messages
+ if is_given(self._opts.guardrail_config):
+ opts["guardrailConfig"] = self._opts.guardrail_config
+ # Selective guardrail: only guard the last user's message
+ for message in reversed(messages):
+ if message["role"] != "user":
+ continue
+
+ message["content"] = [
+ {"guardContent": {"text": block}} if "text" in block else block
+ for block in message["content"]
+ ]
+ break
+ opts["messages"] = cast(list[MessageUnionTypeDef], messages)

if extra_data.system_messages:
- system_messages: list[dict[str, str | dict]] = [
+ system_messages: list[SystemContentBlockTypeDef] = [
{"text": content} for content in extra_data.system_messages
]
if self._opts.cache_system:
system_messages.append({"cachePoint": {"type": "default"}})
opts["system"] = system_messages

- inference_config: dict[str, Any] = {}
+ inference_config: InferenceConfigurationTypeDef = {}
if is_given(self._opts.max_output_tokens):
inference_config["maxTokens"] = self._opts.max_output_tokens
temperature = temperature if is_given(temperature) else self._opts.temperature
@@ -204,7 +229,7 @@ def _get_tool_config() -> dict[str, Any] | None:
tools=tools or [],
session=self._session,
conn_options=conn_options,
- extra_kwargs=opts,
+ opts=opts,
)


@@ -217,21 +242,21 @@ def __init__(
session: aioboto3.Session,
conn_options: APIConnectOptions,
tools: list[FunctionTool | RawFunctionTool],
- extra_kwargs: dict[str, Any],
+ opts: ConverseStreamRequestTypeDef,
) -> None:
super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
self._llm: LLM = llm
- self._opts = extra_kwargs
+ self._opts = opts
self._session = session
self._tool_call_id: str | None = None
self._fnc_name: str | None = None
- self._fnc_raw_arguments: str | None = None
+ self._fnc_arg_parts: list[str] | None = None
self._text: str = ""

async def _run(self) -> None:
retryable = True
try:
- config = Config(user_agent_extra="x-client-framework:livekit-plugins-aws")
+ config = AioConfig(user_agent_extra="x-client-framework:livekit-plugins-aws")
async with self._session.client("bedrock-runtime", config=config) as client:
response = await client.converse_stream(**self._opts)
request_id = response["ResponseMetadata"]["RequestId"]
@@ -254,17 +279,24 @@ async def _run(self) -> None:
retryable=retryable,
) from e

- def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
- if "contentBlockStart" in chunk:
- tool_use = chunk["contentBlockStart"]["start"]["toolUse"]
+ def _parse_chunk(
+ self, request_id: str, chunk: ConverseStreamOutputTypeDef
+ ) -> llm.ChatChunk | None:
+ if "contentBlockStart" in chunk and "toolUse" in (
+ start_block := chunk["contentBlockStart"]["start"]
+ ):
+ tool_use = start_block["toolUse"]
self._tool_call_id = tool_use["toolUseId"]
self._fnc_name = tool_use["name"]
- self._fnc_raw_arguments = ""
+ self._fnc_arg_parts = []

elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"]
if "toolUse" in delta:
- self._fnc_raw_arguments += delta["toolUse"]["input"]
+ if self._fnc_arg_parts is None:
+ logger.warning("Received delta before block start")
+ self._fnc_arg_parts = []
+ self._fnc_arg_parts.append(delta["toolUse"]["input"])
elif "text" in delta:
return llm.ChatChunk(
id=request_id,
@@ -290,13 +322,10 @@ def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
)
elif "contentBlockStop" in chunk:
- if self._tool_call_id is None:
- logger.warning("aws bedrock llm: no tool call id in the response")
- return None
+ if self._tool_call_id:
if self._fnc_name is None:
logger.warning("aws bedrock llm: no function name in the response")
return None
- if self._fnc_raw_arguments is None:
+ if self._fnc_arg_parts is None:
logger.warning("aws bedrock llm: no function arguments in the response")
return None
chat_chunk = llm.ChatChunk(
@@ -305,13 +334,13 @@ def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
role="assistant",
tool_calls=[
FunctionToolCall(
- arguments=self._fnc_raw_arguments,
+ arguments="".join(self._fnc_arg_parts),
name=self._fnc_name,
call_id=self._tool_call_id,
),
],
),
)
- self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
+ self._tool_call_id = self._fnc_name = self._fnc_arg_parts = None
return chat_chunk
return None
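Usage note for the new `guardrail_config` option: it is forwarded as `guardrailConfig` on the Bedrock `converse_stream` request, and the selective-guard loop wraps only the text blocks of the last user message in `guardContent`. A sketch of how a caller might opt in (the identifier and version are placeholders for a guardrail that already exists in your AWS account and region):

```python
from livekit.plugins import aws

# Placeholder guardrail identifier/version; these must reference a
# guardrail configured in the target AWS account and region.
llm = aws.LLM(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region="us-east-1",
    guardrail_config={
        "guardrailIdentifier": "gr0abcd1234",
        "guardrailVersion": "1",
        "trace": "enabled",  # optional: include guardrail trace output in the stream
    },
)
```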
10 changes: 5 additions & 5 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/tts.py
@@ -15,10 +15,10 @@
from dataclasses import dataclass, replace
from typing import cast

- import aioboto3  # type: ignore
- import botocore  # type: ignore
- import botocore.exceptions  # type: ignore
- from aiobotocore.config import AioConfig  # type: ignore
+ import aioboto3
+ import botocore
+ import botocore.exceptions
+ from aiobotocore.config import AioConfig

from livekit.agents import (
APIConnectionError,
@@ -102,7 +102,7 @@ def __init__(
voice=voice,
speech_engine=speech_engine,
text_type=text_type,
- region=region or None,
+ region=region,
language=language or None,
sample_rate=sample_rate,
)
35 changes: 18 additions & 17 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/utils.py
@@ -1,5 +1,7 @@
from __future__ import annotations

+ from typing import TYPE_CHECKING

from livekit.agents import llm
from livekit.agents.llm import FunctionTool, RawFunctionTool
from livekit.agents.llm.tool_context import (
@@ -8,36 +10,35 @@
is_raw_function_tool,
)

+ if TYPE_CHECKING:
+ from types_aiobotocore_bedrock_runtime.type_defs import ToolSpecificationTypeDef, ToolTypeDef
__all__ = ["to_fnc_ctx"]

DEFAULT_REGION = "us-east-1"


- def to_fnc_ctx(fncs: list[FunctionTool | RawFunctionTool]) -> list[dict]:
+ def to_fnc_ctx(fncs: list[FunctionTool | RawFunctionTool]) -> list[ToolTypeDef]:
return [_build_tool_spec(fnc) for fnc in fncs]


- def _build_tool_spec(function: FunctionTool | RawFunctionTool) -> dict:
+ def _build_tool_spec(function: FunctionTool | RawFunctionTool) -> ToolTypeDef:
if is_function_tool(function):
fnc = llm.utils.build_legacy_openai_schema(function, internally_tagged=True)
- return {
- "toolSpec": _strip_nones(
- {
- "name": fnc["name"],
- "description": fnc["description"] if fnc["description"] else None,
- "inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
- }
- )
- }
+ spec: ToolSpecificationTypeDef = {
+ "name": fnc["name"],
+ "inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
+ }
+ if fnc["description"]:
+ spec["description"] = fnc["description"]
+ return {"toolSpec": spec}
elif is_raw_function_tool(function):
info = get_raw_function_info(function)
return {
"toolSpec": _strip_nones(
{
"name": info.name,
"description": info.raw_schema.get("description", ""),
"inputSchema": {"json": info.raw_schema.get("parameters", {})},
}
)
"toolSpec": {
"name": info.name,
"description": info.raw_schema.get("description", ""),
"inputSchema": {"json": info.raw_schema.get("parameters", {})},
}
}
else:
raise ValueError("Invalid function tool")
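For reference, the shape `_build_tool_spec` now emits for a plain function tool, with `_strip_nones` replaced by conditionally setting `description` (the values below are illustrative, not taken from the PR):

```python
# Illustrative Bedrock Converse tool entry; in real usage the name,
# description, and JSON schema come from the registered function tool.
example_tool = {
    "toolSpec": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "inputSchema": {
            "json": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            }
        },
    }
}
```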
7 changes: 6 additions & 1 deletion livekit-plugins/livekit-plugins-aws/pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"livekit-agents>=1.2.14",
"aioboto3>=14.1.0",
"aioboto3>=15.2.0",
"amazon-transcribe>=0.6.4",
]

@@ -47,3 +47,8 @@

[tool.hatch.build.targets.sdist]
include = ["/livekit"]

+ [dependency-groups]
+ dev = [
+ "types-aioboto3[bedrock-runtime]>=15.2.0",
+ ]