diff --git a/agents/manifest-lock.json b/agents/manifest-lock.json index 902dca63..c4921982 100644 --- a/agents/manifest-lock.json +++ b/agents/manifest-lock.json @@ -141,7 +141,8 @@ "type": "system", "name": "nlohmann_json", "version": "3.11.2", - "hash": "72b15822c7ea9deef5e7ad96216ac55e93f11b00466dd1943afd5ee276e99d19" + "hash": "72b15822c7ea9deef5e7ad96216ac55e93f11b00466dd1943afd5ee276e99d19", + "supports": [] }, { "type": "system", diff --git a/agents/property.json b/agents/property.json index 10371751..68888fb1 100644 --- a/agents/property.json +++ b/agents/property.json @@ -2201,6 +2201,15 @@ "extension_group": "transcriber", "addon": "message_collector", "name": "message_collector" + }, + { + "type": "extension", + "extension_group": "tools", + "addon": "weatherapi_tool_python", + "name": "weatherapi_tool_python", + "property": { + "api_key": "${env:WEATHERAPI_API_KEY}" + } } ], "connections": [ @@ -2219,6 +2228,21 @@ } ] }, + { + "extension_group": "tools", + "extension": "weatherapi_tool_python", + "cmd": [ + { + "name": "tool_register", + "dest": [ + { + "extension_group": "llm", + "extension": "openai_v2v_python" + } + ] + } + ] + }, { "extension_group": "llm", "extension": "openai_v2v_python", @@ -2253,6 +2277,15 @@ "extension": "agora_rtc" } ] + }, + { + "name": "tool_call", + "dest": [ + { + "extension_group": "tools", + "extension": "weatherapi_tool_python" + } + ] } ] }, diff --git a/agents/ten_packages/extension/openai_v2v_python/extension.py b/agents/ten_packages/extension/openai_v2v_python/extension.py index 881ab99c..1457498a 100644 --- a/agents/ten_packages/extension/openai_v2v_python/extension.py +++ b/agents/ten_packages/extension/openai_v2v_python/extension.py @@ -9,6 +9,8 @@ import threading import base64 from datetime import datetime +from typing import Awaitable +from functools import partial from ten import ( AudioFrame, @@ -22,9 +24,12 @@ ) from ten.audio_frame import AudioFrameDataFmt from .log import logger + +from .tools import ToolRegistry from .conf import RealtimeApiConfig, BASIC_PROMPT from .realtime.connection import RealtimeApiConnection from .realtime.struct import * +from .tools import ToolRegistry # properties PROPERTY_API_KEY = "api_key" # Required @@ -41,6 +46,15 @@ DEFAULT_VOICE = Voices.Alloy +CMD_TOOL_REGISTER = "tool_register" +CMD_TOOL_CALL = "tool_call" +CMD_PROPERTY_NAME = "name" +CMD_PROPERTY_ARGS = "args" + +TOOL_REGISTER_PROPERTY_NAME = "name" +TOOL_REGISTER_PROPERTY_DESCRIPTON = "description" +TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters" + class Role(str, Enum): User = "user" Assistant = "assistant" @@ -59,6 +73,7 @@ def __init__(self, name: str): self.connected: bool = False self.session_id: str = "" self.session: SessionUpdateParams = None + self.last_updated = None self.ctx: dict = {} # audo related @@ -71,6 +86,7 @@ def __init__(self, name: str): self.remote_stream_id: int = 0 self.channel_name: str = "" self.dump: bool = False + self.registry = ToolRegistry() def on_start(self, ten_env: TenEnv) -> None: logger.info("OpenAIV2VExtension on_start") @@ -85,6 +101,8 @@ def start_event_loop(loop): self.thread = threading.Thread( target=start_event_loop, args=(self.loop,)) self.thread.start() + + self._register_local_tools() asyncio.run_coroutine_threadsafe(self._init_connection(), self.loop) @@ -105,7 +123,7 @@ def on_stop(self, ten_env: TenEnv) -> None: def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: try: stream_id = audio_frame.get_property_int("stream_id") - # logger.debug(f"on_audio_frame {stream_id}") + #logger.debug(f"on_audio_frame {stream_id}") if self.channel_name == "": self.channel_name = audio_frame.get_property_string("channel") @@ -129,6 +147,11 @@ def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: # Should not be here def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: + cmd_name = cmd.get_name() + + if cmd_name == CMD_TOOL_REGISTER: + self._on_tool_register(ten_env, cmd) + cmd_result = CmdResult.create(StatusCode.OK) ten_env.return_result(cmd_result, cmd) @@ -136,6 +159,10 @@ def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: def on_data(self, ten_env: TenEnv, data: Data) -> None: pass + def on_config_changed(self) -> None: + # update session again + return + async def _init_connection(self): try: self.conn = RealtimeApiConnection( @@ -173,7 +200,7 @@ def get_time_ms() -> int: text = self._greeting_text() await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": text}]))) - await self.conn.send_request(ResponseCreate(response=ResponseCreateParams())) + await self.conn.send_request(ResponseCreate()) # update_conversation = self.update_conversation() # await self.conn.send_request(update_conversation) @@ -222,7 +249,7 @@ def get_time_ms() -> int: logger.info(f"Output item done {message.item}") case ResponseOutputItemAdded(): logger.info( - f"Output item added {message.output_index} {message.item.id}") + f"Output item added {message.output_index} {message.item}") case ResponseAudioDelta(): if message.response_id in flushed: logger.warning( @@ -253,6 +280,12 @@ def get_time_ms() -> int: relative_start_ms = get_time_ms() - message.audio_end_ms logger.info( f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}") + case ResponseFunctionCallArgumentsDone(): + tool_call_id = message.call_id + name = message.name + arguments = message.arguments + logger.info(f"need to call func {name}") + await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output) # TODO rebuild this into async, or it will block the thread case ErrorMessage(): logger.error( f"Error message received: {message.error}") @@ -365,12 +398,15 @@ def _fetch_properties(self, ten_env: TenEnv): def _update_session(self) -> SessionUpdate: prompt = self._replace(self.config.instruction) + self.last_updated = datetime.now() return SessionUpdate(session=SessionUpdateParams( instructions=prompt, model=self.config.model, voice=self.config.voice, input_audio_transcription=InputAudioTranscription( - model="whisper-1") + model="whisper-1"), + tool_choice="auto", + tools=self.registry.get_tools() )) ''' @@ -439,6 +475,49 @@ def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None: with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file: dump_file.write(buf) + def _register_local_tools(self) -> None: + self.ctx["tools"] = self.registry.to_prompt() + + def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd): + try: + name = cmd.get_property_string(TOOL_REGISTER_PROPERTY_NAME) + description = cmd.get_property_string(TOOL_REGISTER_PROPERTY_DESCRIPTON) + pstr = cmd.get_property_string(TOOL_REGISTER_PROPERTY_PARAMETERS) + parameters = json.loads(pstr) + p = partial(self._remote_tool_call, ten_env) + self.registry.register( + name=name, description=description, + callback=p, + parameters=parameters) + logger.info(f"on tool register {name} {description}") + self.on_config_changed() + except: + logger.exception(f"Failed to register") + + async def _remote_tool_call(self, ten_env: TenEnv, name:str, args: str, callback: Awaitable): + logger.info(f"_remote_tool_call {name} {args}") + c = Cmd.create(CMD_TOOL_CALL) + c.set_property_string(CMD_PROPERTY_NAME, name) + c.set_property_string(CMD_PROPERTY_ARGS, args) + ten_env.send_cmd(c, lambda ten, result: asyncio.run_coroutine_threadsafe( + callback(result.get_property_string("response")), self.loop)) + logger.info(f"_remote_tool_call finish {name} {args}") + + async def _on_tool_output(self, tool_call_id:str, result: str): + logger.info(f"_on_tool_output {tool_call_id} {result}") + try: + tool_response = ItemCreate( + item=FunctionCallOutputItemParam( + call_id=tool_call_id, + output=result, + ) + ) + + await self.conn.send_request(tool_response) + await self.conn.send_request(ResponseCreate()) + except: + logger.exception("Failed to handle tool output") + def _greeting_text(self) -> str: text = "Hi, there." if self.config.language == "zh-CN": diff --git a/agents/ten_packages/extension/openai_v2v_python/id.py b/agents/ten_packages/extension/openai_v2v_python/id.py deleted file mode 100644 index e8fe56c8..00000000 --- a/agents/ten_packages/extension/openai_v2v_python/id.py +++ /dev/null @@ -1,20 +0,0 @@ -import random -import string - - -def generate_rand_str(prefix: str, len: int = 16) -> str: - # Generate a random string of specified length with the given prefix - random_str = "".join(random.choices(string.ascii_letters + string.digits, k=len)) - return f"{prefix}_{random_str}" - - -def generate_client_event_id() -> str: - return generate_rand_str("cevt") - - -def generate_event_id() -> str: - return generate_rand_str("event") - - -def generate_response_id() -> str: - return generate_rand_str("resp") diff --git a/agents/ten_packages/extension/openai_v2v_python/tools.py b/agents/ten_packages/extension/openai_v2v_python/tools.py new file mode 100644 index 00000000..a3e532d9 --- /dev/null +++ b/agents/ten_packages/extension/openai_v2v_python/tools.py @@ -0,0 +1,91 @@ +import copy +from typing import Dict, Any +from functools import partial + +from .log import logger + +class ToolRegistry: + tools: Dict[str, dict[str, Any]] = {} + def register(self, name:str, description: str, callback, parameters: Any = None) -> None: + info = { + "type": "function", + "name": name, + "description": description, + "callback": callback + } + if parameters: + info["parameters"] = parameters + self.tools[name] = info + logger.info(f"register tool {name} {description}") + + def to_prompt(self) -> str: + prompt = "" + if self.tools: + prompt = "You have several tools that you can get help from:\n" + for name, t in self.tools.items(): + desc = t["description"] + prompt += f"- ***{name}***: {desc}" + return prompt + + def unregister(self, name:str) -> None: + if name in self.tools: + del self.tools[name] + logger.info(f"unregister tool {name}") + + def get_tools(self) -> list[dict[str, Any]]: + result = [] + for _, t in self.tools.items(): + info = copy.copy(t) + del info["callback"] + result.append(info) + return result + + async def on_func_call(self, call_id: str, name: str, args: str, callback): + try: + if name in self.tools: + t = self.tools[name] + # FIXME add args check + if t.get("callback"): + p = partial(callback, call_id) + await t["callback"](name, args, p) + else: + logger.warning(f"Failed to find func {name}") + except: + logger.exception(f"Failed to call func {name}") + # TODO What to do if func call is dead + callback(None) + +if __name__ == "__main__": + r = ToolRegistry() + + def weather_check(location:str = "", datetime:str = ""): + logger.info(f"on weather check {location}, {datetime}") + + def on_tool_completion(result: Any): + logger.info(f"on tool completion {result}") + + r.register( + name="weather", description="This is a weather check func, if the user is asking about the weather. you need to summarize location and time information from the context as parameters. if the information is lack, please ask for more detail before calling.", + callback=weather_check, + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location or region for the weather check.", + }, + "datetime": { + "type": "string", + "description": "The date and time for the weather check. The datetime should use format like 2024-10-01T16:42:00.", + } + }, + "required": ["location"], + }) + print(r.to_prompt()) + print(r.get_tools()) + print(r.on_func_call("weather", {"location":"LA", "datetime":"2024-10-01T16:43:01"}, on_tool_completion)) + r.unregister("weather") + print(r.to_prompt()) + print(r.get_tools()) + print(r.on_func_call("weather", {"location":"LA", "datetime":"2024-10-01T16:43:01"}, on_tool_completion)) + \ No newline at end of file diff --git a/agents/ten_packages/extension/weatherapi_tool_python/BUILD.gn b/agents/ten_packages/extension/weatherapi_tool_python/BUILD.gn new file mode 100644 index 00000000..15a31a94 --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/BUILD.gn @@ -0,0 +1,21 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2022-11. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# +import("//build/feature/ten_package.gni") + +ten_package("weatherapi_tool_python") { + package_kind = "extension" + + resources = [ + "__init__.py", + "addon.py", + "extension.py", + "log.py", + "manifest.json", + "property.json", + ] +} diff --git a/agents/ten_packages/extension/weatherapi_tool_python/README.md b/agents/ten_packages/extension/weatherapi_tool_python/README.md new file mode 100644 index 00000000..045d9de7 --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/README.md @@ -0,0 +1,29 @@ +# weatherapi_tool_python + + + +## Features + + + +- xxx feature + +## API + +Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). + + + +## Development + +### Build + + + +### Unit test + + + +## Misc + + diff --git a/agents/ten_packages/extension/weatherapi_tool_python/__init__.py b/agents/ten_packages/extension/weatherapi_tool_python/__init__.py new file mode 100644 index 00000000..4a1c0614 --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/__init__.py @@ -0,0 +1,11 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2024-08. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# +from . import addon +from .log import logger + +logger.info("weatherapi_tool_python extension loaded") diff --git a/agents/ten_packages/extension/weatherapi_tool_python/addon.py b/agents/ten_packages/extension/weatherapi_tool_python/addon.py new file mode 100644 index 00000000..b37c9546 --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/addon.py @@ -0,0 +1,22 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2024-08. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# +from ten import ( + Addon, + register_addon_as_extension, + TenEnv, +) +from .extension import WeatherToolExtension +from .log import logger + + +@register_addon_as_extension("weatherapi_tool_python") +class WeatherToolExtensionAddon(Addon): + + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + logger.info("WeatherToolExtensionAddon on_create_instance") + ten_env.on_create_instance_done(WeatherToolExtension(name), context) diff --git a/agents/ten_packages/extension/weatherapi_tool_python/extension.py b/agents/ten_packages/extension/weatherapi_tool_python/extension.py new file mode 100644 index 00000000..d9485062 --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/extension.py @@ -0,0 +1,139 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2024-08. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# + +import json +import requests + +from typing import Any + +from ten import ( + AudioFrame, + VideoFrame, + Extension, + TenEnv, + Cmd, + StatusCode, + CmdResult, + Data, +) +from .log import logger + +CMD_TOOL_REGISTER = "tool_register" +CMD_TOOL_CALL = "tool_call" +CMD_PROPERTY_NAME = "name" +CMD_PROPERTY_ARGS = "args" + +TOOL_REGISTER_PROPERTY_NAME = "name" +TOOL_REGISTER_PROPERTY_DESCRIPTON = "description" +TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters" + +TOOL_NAME = "get_current_weather" +TOOL_DESCRIPTION = "Determine weather in my location" +TOOL_PARAMETERS = { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + } + }, + "required": ["location"], + } + +PROPERTY_API_KEY = "api_key" # Required + +class WeatherToolExtension(Extension): + api_key: str = "" + + def on_init(self, ten_env: TenEnv) -> None: + logger.info("WeatherToolExtension on_init") + + ten_env.on_init_done() + + def on_start(self, ten_env: TenEnv) -> None: + logger.info("WeatherToolExtension on_start") + + try: + api_key = ten_env.get_property_string(PROPERTY_API_KEY) + self.api_key = api_key + except Exception as err: + logger.info( + f"GetProperty required {PROPERTY_API_KEY} failed, err: {err}") + return + + # Register func + c = Cmd.create(CMD_TOOL_REGISTER) + c.set_property_string(TOOL_REGISTER_PROPERTY_NAME, TOOL_NAME) + c.set_property_string(TOOL_REGISTER_PROPERTY_DESCRIPTON, TOOL_DESCRIPTION) + c.set_property_string(TOOL_REGISTER_PROPERTY_PARAMETERS, json.dumps(TOOL_PARAMETERS)) + ten_env.send_cmd(c, lambda ten, result: logger.info(f"register done, {result}")) + + ten_env.on_start_done() + + def on_stop(self, ten_env: TenEnv) -> None: + logger.info("WeatherToolExtension on_stop") + + ten_env.on_stop_done() + + def on_deinit(self, ten_env: TenEnv) -> None: + logger.info("WeatherToolExtension on_deinit") + ten_env.on_deinit_done() + + def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: + cmd_name = cmd.get_name() + logger.info(f"on_cmd name {cmd_name} {cmd.to_json()}") + + try: + name = cmd.get_property_string(CMD_PROPERTY_NAME) + if name == TOOL_NAME: + try: + args = cmd.get_property_string(CMD_PROPERTY_ARGS) + arg_dict = json.loads(args) + if "location" in arg_dict: + logger.info(f"before get current weather {name}") + resp = self._get_current_weather(arg_dict["location"]) + logger.info(f"after get current weather {resp}") + cmd_result = CmdResult.create(StatusCode.OK) + cmd_result.set_property_string("response", json.dumps(resp)) + ten_env.return_result(cmd_result, cmd) + return + else: + logger.error(f"no location in args {args}") + cmd_result = CmdResult.create(StatusCode.ERROR) + ten_env.return_result(cmd_result, cmd) + return + except: + logger.exception("Failed to get weather") + cmd_result = CmdResult.create(StatusCode.ERROR) + ten_env.return_result(cmd_result, cmd) + return + else: + logger.error(f"unknown tool name {name}") + except: + logger.exception("Failed to get tool name") + cmd_result = CmdResult.create(StatusCode.ERROR) + ten_env.return_result(cmd_result, cmd) + return + + cmd_result = CmdResult.create(StatusCode.OK) + ten_env.return_result(cmd_result, cmd) + + def on_data(self, ten_env: TenEnv, data: Data) -> None: + pass + + def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: + pass + + def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: + pass + + def _get_current_weather(self, location:str) -> Any: + url = f"http://api.weatherapi.com/v1/current.json?key={self.api_key}&q={location}&aqi=no" + response = requests.get(url) + result = response.json() + return result \ No newline at end of file diff --git a/agents/ten_packages/extension/weatherapi_tool_python/log.py b/agents/ten_packages/extension/weatherapi_tool_python/log.py new file mode 100644 index 00000000..24afd88c --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/log.py @@ -0,0 +1,22 @@ +# +# +# Agora Real Time Engagement +# Created by Wei Hu in 2024-08. +# Copyright (c) 2024 Agora IO. All rights reserved. +# +# +import logging + +logger = logging.getLogger("weatherapi_tool_python") +logger.setLevel(logging.INFO) + +formatter_str = ( + "%(asctime)s - %(name)s - %(levelname)s - %(process)d - " + "[%(filename)s:%(lineno)d] - %(message)s" +) +formatter = logging.Formatter(formatter_str) + +console_handler = logging.StreamHandler() +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/weatherapi_tool_python/manifest.json b/agents/ten_packages/extension/weatherapi_tool_python/manifest.json new file mode 100644 index 00000000..c6c359d6 --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/manifest.json @@ -0,0 +1,23 @@ +{ + "type": "extension", + "name": "weatherapi_tool_python", + "version": "0.1.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.2" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md" + ] + }, + "api": {} +} \ No newline at end of file diff --git a/agents/ten_packages/extension/weatherapi_tool_python/property.json b/agents/ten_packages/extension/weatherapi_tool_python/property.json new file mode 100644 index 00000000..9e26dfee --- /dev/null +++ b/agents/ten_packages/extension/weatherapi_tool_python/property.json @@ -0,0 +1 @@ +{} \ No newline at end of file