Skip to content

Commit

Permalink
Merge pull request #308 from TEN-framework/feat/tools
Browse files Browse the repository at this point in the history
Feat/tools
  • Loading branch information
tomasliu-agora authored Oct 6, 2024
2 parents bedb98b + 59ee505 commit f672a15
Show file tree
Hide file tree
Showing 13 changed files with 477 additions and 25 deletions.
3 changes: 2 additions & 1 deletion agents/manifest-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@
"type": "system",
"name": "nlohmann_json",
"version": "3.11.2",
"hash": "72b15822c7ea9deef5e7ad96216ac55e93f11b00466dd1943afd5ee276e99d19"
"hash": "72b15822c7ea9deef5e7ad96216ac55e93f11b00466dd1943afd5ee276e99d19",
"supports": []
},
{
"type": "system",
Expand Down
33 changes: 33 additions & 0 deletions agents/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -2201,6 +2201,15 @@
"extension_group": "transcriber",
"addon": "message_collector",
"name": "message_collector"
},
{
"type": "extension",
"extension_group": "tools",
"addon": "weatherapi_tool_python",
"name": "weatherapi_tool_python",
"property": {
"api_key": "${env:WEATHERAPI_API_KEY}"
}
}
],
"connections": [
Expand All @@ -2219,6 +2228,21 @@
}
]
},
{
"extension_group": "tools",
"extension": "weatherapi_tool_python",
"cmd": [
{
"name": "tool_register",
"dest": [
{
"extension_group": "llm",
"extension": "openai_v2v_python"
}
]
}
]
},
{
"extension_group": "llm",
"extension": "openai_v2v_python",
Expand Down Expand Up @@ -2253,6 +2277,15 @@
"extension": "agora_rtc"
}
]
},
{
"name": "tool_call",
"dest": [
{
"extension_group": "tools",
"extension": "weatherapi_tool_python"
}
]
}
]
},
Expand Down
87 changes: 83 additions & 4 deletions agents/ten_packages/extension/openai_v2v_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import threading
import base64
from datetime import datetime
from typing import Awaitable
from functools import partial

from ten import (
AudioFrame,
Expand All @@ -22,9 +24,12 @@
)
from ten.audio_frame import AudioFrameDataFmt
from .log import logger

from .tools import ToolRegistry
from .conf import RealtimeApiConfig, BASIC_PROMPT
from .realtime.connection import RealtimeApiConnection
from .realtime.struct import *
from .tools import ToolRegistry

# properties
PROPERTY_API_KEY = "api_key" # Required
Expand All @@ -41,6 +46,15 @@

DEFAULT_VOICE = Voices.Alloy

CMD_TOOL_REGISTER = "tool_register"
CMD_TOOL_CALL = "tool_call"
CMD_PROPERTY_NAME = "name"
CMD_PROPERTY_ARGS = "args"

TOOL_REGISTER_PROPERTY_NAME = "name"
TOOL_REGISTER_PROPERTY_DESCRIPTON = "description"
TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters"

class Role(str, Enum):
User = "user"
Assistant = "assistant"
Expand All @@ -59,6 +73,7 @@ def __init__(self, name: str):
self.connected: bool = False
self.session_id: str = ""
self.session: SessionUpdateParams = None
self.last_updated = None
self.ctx: dict = {}

# audo related
Expand All @@ -71,6 +86,7 @@ def __init__(self, name: str):
self.remote_stream_id: int = 0
self.channel_name: str = ""
self.dump: bool = False
self.registry = ToolRegistry()

def on_start(self, ten_env: TenEnv) -> None:
logger.info("OpenAIV2VExtension on_start")
Expand All @@ -85,6 +101,8 @@ def start_event_loop(loop):
self.thread = threading.Thread(
target=start_event_loop, args=(self.loop,))
self.thread.start()

self._register_local_tools()

asyncio.run_coroutine_threadsafe(self._init_connection(), self.loop)

Expand All @@ -105,7 +123,7 @@ def on_stop(self, ten_env: TenEnv) -> None:
def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
try:
stream_id = audio_frame.get_property_int("stream_id")
# logger.debug(f"on_audio_frame {stream_id}")
#logger.debug(f"on_audio_frame {stream_id}")
if self.channel_name == "":
self.channel_name = audio_frame.get_property_string("channel")

Expand All @@ -129,13 +147,22 @@ def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:

# Should not be here
def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()

if cmd_name == CMD_TOOL_REGISTER:
self._on_tool_register(ten_env, cmd)

cmd_result = CmdResult.create(StatusCode.OK)
ten_env.return_result(cmd_result, cmd)

# Should not be here
def on_data(self, ten_env: TenEnv, data: Data) -> None:
pass

def on_config_changed(self) -> None:
# update session again
return

async def _init_connection(self):
try:
self.conn = RealtimeApiConnection(
Expand Down Expand Up @@ -173,7 +200,7 @@ def get_time_ms() -> int:

text = self._greeting_text()
await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": text}])))
await self.conn.send_request(ResponseCreate(response=ResponseCreateParams()))
await self.conn.send_request(ResponseCreate())

# update_conversation = self.update_conversation()
# await self.conn.send_request(update_conversation)
Expand Down Expand Up @@ -222,7 +249,7 @@ def get_time_ms() -> int:
logger.info(f"Output item done {message.item}")
case ResponseOutputItemAdded():
logger.info(
f"Output item added {message.output_index} {message.item.id}")
f"Output item added {message.output_index} {message.item}")
case ResponseAudioDelta():
if message.response_id in flushed:
logger.warning(
Expand Down Expand Up @@ -253,6 +280,12 @@ def get_time_ms() -> int:
relative_start_ms = get_time_ms() - message.audio_end_ms
logger.info(
f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}")
case ResponseFunctionCallArgumentsDone():
tool_call_id = message.call_id
name = message.name
arguments = message.arguments
logger.info(f"need to call func {name}")
await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output) # TODO rebuild this into async, or it will block the thread
case ErrorMessage():
logger.error(
f"Error message received: {message.error}")
Expand Down Expand Up @@ -365,12 +398,15 @@ def _fetch_properties(self, ten_env: TenEnv):

def _update_session(self) -> SessionUpdate:
prompt = self._replace(self.config.instruction)
self.last_updated = datetime.now()
return SessionUpdate(session=SessionUpdateParams(
instructions=prompt,
model=self.config.model,
voice=self.config.voice,
input_audio_transcription=InputAudioTranscription(
model="whisper-1")
model="whisper-1"),
tool_choice="auto",
tools=self.registry.get_tools()
))

'''
Expand Down Expand Up @@ -439,6 +475,49 @@ def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None:
with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file:
dump_file.write(buf)

def _register_local_tools(self) -> None:
self.ctx["tools"] = self.registry.to_prompt()

def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd):
try:
name = cmd.get_property_string(TOOL_REGISTER_PROPERTY_NAME)
description = cmd.get_property_string(TOOL_REGISTER_PROPERTY_DESCRIPTON)
pstr = cmd.get_property_string(TOOL_REGISTER_PROPERTY_PARAMETERS)
parameters = json.loads(pstr)
p = partial(self._remote_tool_call, ten_env)
self.registry.register(
name=name, description=description,
callback=p,
parameters=parameters)
logger.info(f"on tool register {name} {description}")
self.on_config_changed()
except:
logger.exception(f"Failed to register")

async def _remote_tool_call(self, ten_env: TenEnv, name:str, args: str, callback: Awaitable):
logger.info(f"_remote_tool_call {name} {args}")
c = Cmd.create(CMD_TOOL_CALL)
c.set_property_string(CMD_PROPERTY_NAME, name)
c.set_property_string(CMD_PROPERTY_ARGS, args)
ten_env.send_cmd(c, lambda ten, result: asyncio.run_coroutine_threadsafe(
callback(result.get_property_string("response")), self.loop))
logger.info(f"_remote_tool_call finish {name} {args}")

async def _on_tool_output(self, tool_call_id:str, result: str):
logger.info(f"_on_tool_output {tool_call_id} {result}")
try:
tool_response = ItemCreate(
item=FunctionCallOutputItemParam(
call_id=tool_call_id,
output=result,
)
)

await self.conn.send_request(tool_response)
await self.conn.send_request(ResponseCreate())
except:
logger.exception("Failed to handle tool output")

def _greeting_text(self) -> str:
text = "Hi, there."
if self.config.language == "zh-CN":
Expand Down
20 changes: 0 additions & 20 deletions agents/ten_packages/extension/openai_v2v_python/id.py

This file was deleted.

91 changes: 91 additions & 0 deletions agents/ten_packages/extension/openai_v2v_python/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import copy
from typing import Dict, Any
from functools import partial

from .log import logger

class ToolRegistry:
tools: Dict[str, dict[str, Any]] = {}
def register(self, name:str, description: str, callback, parameters: Any = None) -> None:
info = {
"type": "function",
"name": name,
"description": description,
"callback": callback
}
if parameters:
info["parameters"] = parameters
self.tools[name] = info
logger.info(f"register tool {name} {description}")

def to_prompt(self) -> str:
prompt = ""
if self.tools:
prompt = "You have several tools that you can get help from:\n"
for name, t in self.tools.items():
desc = t["description"]
prompt += f"- ***{name}***: {desc}"
return prompt

def unregister(self, name:str) -> None:
if name in self.tools:
del self.tools[name]
logger.info(f"unregister tool {name}")

def get_tools(self) -> list[dict[str, Any]]:
result = []
for _, t in self.tools.items():
info = copy.copy(t)
del info["callback"]
result.append(info)
return result

async def on_func_call(self, call_id: str, name: str, args: str, callback):
try:
if name in self.tools:
t = self.tools[name]
# FIXME add args check
if t.get("callback"):
p = partial(callback, call_id)
await t["callback"](name, args, p)
else:
logger.warning(f"Failed to find func {name}")
except:
logger.exception(f"Failed to call func {name}")
# TODO What to do if func call is dead
callback(None)

if __name__ == "__main__":
r = ToolRegistry()

def weather_check(location:str = "", datetime:str = ""):
logger.info(f"on weather check {location}, {datetime}")

def on_tool_completion(result: Any):
logger.info(f"on tool completion {result}")

r.register(
name="weather", description="This is a weather check func, if the user is asking about the weather. you need to summarize location and time information from the context as parameters. if the information is lack, please ask for more detail before calling.",
callback=weather_check,
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location or region for the weather check.",
},
"datetime": {
"type": "string",
"description": "The date and time for the weather check. The datetime should use format like 2024-10-01T16:42:00.",
}
},
"required": ["location"],
})
print(r.to_prompt())
print(r.get_tools())
print(r.on_func_call("weather", {"location":"LA", "datetime":"2024-10-01T16:43:01"}, on_tool_completion))
r.unregister("weather")
print(r.to_prompt())
print(r.get_tools())
print(r.on_func_call("weather", {"location":"LA", "datetime":"2024-10-01T16:43:01"}, on_tool_completion))

21 changes: 21 additions & 0 deletions agents/ten_packages/extension/weatherapi_tool_python/BUILD.gn
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2022-11.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import("//build/feature/ten_package.gni")

ten_package("weatherapi_tool_python") {
package_kind = "extension"

resources = [
"__init__.py",
"addon.py",
"extension.py",
"log.py",
"manifest.json",
"property.json",
]
}
Loading

0 comments on commit f672a15

Please sign in to comment.