From 71086d77862c58de55f91cbacf0ca5802bec40b5 Mon Sep 17 00:00:00 2001 From: TomasLiu Date: Tue, 19 Nov 2024 13:30:32 +0800 Subject: [PATCH 1/2] refactor glue and use base config --- .../deepgram_asr_python/extension.py | 83 ++-- .../deepgram_asr_python/manifest.json | 3 +- .../extension/glue_python_async/__init__.py | 2 - .../extension/glue_python_async/addon.py | 3 +- .../examples/openai_wrapper.py | 169 +++++--- .../extension/glue_python_async/extension.py | 377 ++++++++++++------ .../extension/glue_python_async/log.py | 20 - .../extension/glue_python_async/schema.yml | 307 ++++++++++++-- .../interrupt_detector_python/__init__.py | 5 +- .../{interrupt_detector_addon.py => addon.py} | 6 +- ...upt_detector_extension.py => extension.py} | 23 +- .../interrupt_detector_python/log.py | 13 - .../message_collector_rtm/__init__.py | 3 - .../message_collector_rtm/manifest.json | 18 - .../message_collector_rtm/src/addon.py | 3 +- .../message_collector_rtm/src/extension.py | 6 +- .../message_collector_rtm/src/log.py | 22 - .../weatherapi_tool_python/__init__.py | 3 - .../extension/weatherapi_tool_python/addon.py | 3 +- .../weatherapi_tool_python/extension.py | 44 +- .../extension/weatherapi_tool_python/log.py | 22 - 21 files changed, 731 insertions(+), 404 deletions(-) delete mode 100644 agents/ten_packages/extension/glue_python_async/log.py rename agents/ten_packages/extension/interrupt_detector_python/{interrupt_detector_addon.py => addon.py} (75%) rename agents/ten_packages/extension/interrupt_detector_python/{interrupt_detector_extension.py => extension.py} (83%) delete mode 100644 agents/ten_packages/extension/interrupt_detector_python/log.py delete mode 100644 agents/ten_packages/extension/message_collector_rtm/src/log.py delete mode 100644 agents/ten_packages/extension/weatherapi_tool_python/log.py diff --git a/agents/ten_packages/extension/deepgram_asr_python/extension.py b/agents/ten_packages/extension/deepgram_asr_python/extension.py index 91248c55..a21e8b83 100644 --- a/agents/ten_packages/extension/deepgram_asr_python/extension.py +++ b/agents/ten_packages/extension/deepgram_asr_python/extension.py @@ -11,26 +11,34 @@ import asyncio from deepgram import AsyncListenWebSocketClient, DeepgramClientOptions, LiveTranscriptionEvents, LiveOptions +from dataclasses import dataclass -from .config import DeepgramConfig - -PROPERTY_API_KEY = "api_key" # Required -PROPERTY_LANG = "language" # Optional -PROPERTY_MODEL = "model" # Optional -PROPERTY_SAMPLE_RATE = "sample_rate" # Optional +from ten_ai_base import BaseConfig DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text" DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final" DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id" DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment" +@dataclass +class DeepgramASRConfig(BaseConfig): + api_key: str = "" + language: str = "en-US" + model: str = "nova-2" + sample_rate: int = 16000 + + channels: int = 1 + encoding: str = 'linear16' + interim_results: bool = True + punctuate: bool = True + class DeepgramASRExtension(AsyncExtension): def __init__(self, name: str): super().__init__(name) self.stopped = False - self.deepgram_client : AsyncListenWebSocketClient = None - self.deepgram_config : DeepgramConfig = None + self.client : AsyncListenWebSocketClient = None + self.config : DeepgramASRConfig = None self.ten_env : AsyncTenEnv = None async def on_init(self, ten_env: AsyncTenEnv) -> None: @@ -41,30 +49,15 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: self.loop = asyncio.get_event_loop() self.ten_env = 
ten_env - self.deepgram_config = DeepgramConfig.default_config() + self.config = DeepgramASRConfig.create(ten_env=ten_env) + ten_env.log_info(f"config: {self.config}") - try: - self.deepgram_config.api_key = ten_env.get_property_string(PROPERTY_API_KEY).strip() - except Exception as e: - ten_env.log_error(f"get property {PROPERTY_API_KEY} error: {e}") + if not self.config.api_key: + ten_env.log_error("api_key is required but not set") return - for optional_param in [ - PROPERTY_LANG, - PROPERTY_MODEL, - PROPERTY_SAMPLE_RATE, - ]: - try: - value = ten_env.get_property_string(optional_param).strip() - if value: - self.deepgram_config.__setattr__(optional_param, value) - except Exception as err: - ten_env.log_debug( - f"get property optional {optional_param} failed, err: {err}. Using default value: {self.deepgram_config.__getattribute__(optional_param)}" - ) - - self.deepgram_client = AsyncListenWebSocketClient(config=DeepgramClientOptions( - api_key=self.deepgram_config.api_key, + self.client = AsyncListenWebSocketClient(config=DeepgramClientOptions( + api_key=self.config.api_key, options={"keepalive": "true"} )) @@ -80,12 +73,12 @@ async def on_audio_frame(self, ten_env: AsyncTenEnv, frame: AudioFrame) -> None: return self.stream_id = frame.get_property_int('stream_id') - await self.deepgram_client.send(frame_buf) + await self.client.send(frame_buf) async def on_stop(self, ten_env: AsyncTenEnv) -> None: ten_env.log_info("on_stop") - await self.deepgram_client.finish() + await self.client.finish() async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: cmd_json = cmd.to_json() @@ -118,22 +111,22 @@ async def on_message(_, result, **kwargs): async def on_error(_, error, **kwargs): self.ten_env.log_error(f"deepgram event callback on_error: {error}") - self.deepgram_client.on(LiveTranscriptionEvents.Open, on_open) - self.deepgram_client.on(LiveTranscriptionEvents.Close, on_close) - self.deepgram_client.on(LiveTranscriptionEvents.Transcript, on_message) - self.deepgram_client.on(LiveTranscriptionEvents.Error, on_error) - - options = LiveOptions(language=self.deepgram_config.language, - model=self.deepgram_config.model, - sample_rate=self.deepgram_config.sample_rate, - channels=self.deepgram_config.channels, - encoding=self.deepgram_config.encoding, - interim_results=self.deepgram_config.interim_results, - punctuate=self.deepgram_config.punctuate) + self.client.on(LiveTranscriptionEvents.Open, on_open) + self.client.on(LiveTranscriptionEvents.Close, on_close) + self.client.on(LiveTranscriptionEvents.Transcript, on_message) + self.client.on(LiveTranscriptionEvents.Error, on_error) + + options = LiveOptions(language=self.config.language, + model=self.config.model, + sample_rate=self.config.sample_rate, + channels=self.config.channels, + encoding=self.config.encoding, + interim_results=self.config.interim_results, + punctuate=self.config.punctuate) # connect to websocket - result = await self.deepgram_client.start(options) + result = await self.client.start(options) if result is False: - if self.deepgram_client.status_code == 402: + if self.client.status_code == 402: self.ten_env.log_error("Failed to connect to Deepgram - your account has run out of credits.") else: self.ten_env.log_error("Failed to connect to Deepgram") diff --git a/agents/ten_packages/extension/deepgram_asr_python/manifest.json b/agents/ten_packages/extension/deepgram_asr_python/manifest.json index e7914dd6..9e3298f6 100644 --- a/agents/ten_packages/extension/deepgram_asr_python/manifest.json +++ 
b/agents/ten_packages/extension/deepgram_asr_python/manifest.json @@ -26,8 +26,7 @@ }, "audio_frame_in": [ { - "name": "pcm_frame", - "property": {} + "name": "pcm_frame" } ], "cmd_in": [ diff --git a/agents/ten_packages/extension/glue_python_async/__init__.py b/agents/ten_packages/extension/glue_python_async/__init__.py index b5acaf2d..c22fdd7c 100644 --- a/agents/ten_packages/extension/glue_python_async/__init__.py +++ b/agents/ten_packages/extension/glue_python_async/__init__.py @@ -4,6 +4,4 @@ # See the LICENSE file for more information. # from . import addon -from .log import logger -logger.info("glue_python_async extension loaded") diff --git a/agents/ten_packages/extension/glue_python_async/addon.py b/agents/ten_packages/extension/glue_python_async/addon.py index 107c7394..3f33c4c2 100644 --- a/agents/ten_packages/extension/glue_python_async/addon.py +++ b/agents/ten_packages/extension/glue_python_async/addon.py @@ -15,6 +15,5 @@ class AsyncGlueExtensionAddon(Addon): def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: from .extension import AsyncGlueExtension - from .log import logger - logger.info("AsyncGlueExtensionAddon on_create_instance") + ten_env.log_info("AsyncGlueExtensionAddon on_create_instance") ten_env.on_create_instance_done(AsyncGlueExtension(name), context) diff --git a/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py b/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py index b9299369..8cde92f7 100644 --- a/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py +++ b/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py @@ -1,16 +1,51 @@ import os import openai +import json from openai import AsyncOpenAI -import traceback # Add this import +import traceback +import logging +import logging.config -from typing import List, Union -from pydantic import BaseModel, HttpUrl +from typing import List, Union, Dict, Optional +from pydantic import BaseModel, HttpUrl, ValidationError from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from fastapi import Depends, FastAPI, HTTPException, Request import asyncio +# Enable Pydantic debug mode +from pydantic import BaseConfig + +BaseConfig.debug = True + +# Set up logging +logging.config.dictConfig({ + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "file": { + "level": "DEBUG", + "formatter": "default", + "class": "logging.FileHandler", + "filename": "example.log", + }, + }, + "loggers": { + "": { + "handlers": ["file"], + "level": "DEBUG", + "propagate": True, + }, + }, +}) +logger = logging.getLogger(__name__) + app = FastAPI(title="Chat Completion API", description="API for streaming chat completions with support for text, image, and audio content", version="1.0.0") @@ -27,79 +62,109 @@ class ImageContent(BaseModel): image_url: HttpUrl class AudioContent(BaseModel): - type: str = "audio" - audio_url: HttpUrl - -class Message(BaseModel): - role: str - content: Union[TextContent, ImageContent, AudioContent, List[Union[TextContent, ImageContent, AudioContent]]] + type: str = "input_audio" + input_audio: Dict[str, str] + +class ToolFunction(BaseModel): + name: str + description: Optional[str] + parameters: Optional[Dict] + strict: bool = False + +class 
Tool(BaseModel): + type: str = "function" + function: ToolFunction + +class ToolChoice(BaseModel): + type: str = "function" + function: Optional[Dict] + +class ResponseFormat(BaseModel): + type: str = "json_schema" + json_schema: Optional[Dict[str, str]] + +class SystemMessage(BaseModel): + role: str = "system" + content: Union[str, List[str]] + +class UserMessage(BaseModel): + role: str = "user" + content: Union[str, List[Union[TextContent, ImageContent, AudioContent]]] + +class AssistantMessage(BaseModel): + role: str = "assistant" + content: Union[str, List[TextContent]] = None + audio: Optional[Dict[str, str]] = None + tool_calls: Optional[List[Dict]] = None + +class ToolMessage(BaseModel): + role: str = "tool" + content: Union[str, List[str]] + tool_call_id: str class ChatCompletionRequest(BaseModel): - messages: List[Message] - model: str - temperature: float = 1.0 + context: Optional[Dict] = None + model: Optional[str] = None + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]] + response_format: Optional[ResponseFormat] = None + modalities: List[str] = ["text"] + audio: Optional[Dict[str, str]] = None + tools: Optional[List[Tool]] = None + tool_choice: Optional[Union[str, ToolChoice]] = "auto" + parallel_tool_calls: bool = True stream: bool = True + stream_options: Optional[Dict] = None -def format_openai_messages(messages): - formatted_messages = [] - for msg in messages: - if isinstance(msg.content, list): - content = [] - for item in msg.content: - if item.type == "text": - content.append({"type": "text", "text": item.text}) - elif item.type == "image": - content.append({"type": "image_url", "image_url": str(item.image_url)}) - elif item.type == "audio": - content.append({"type": "audio_url", "audio_url": str(item.audio_url)}) - else: - if msg.content.type == "text": - content = [{"type": "text", "text": msg.content.text}] - elif msg.content.type == "image": - content = [{"type": "image_url", "image_url": {"url": str(msg.content.image_url)}}] - elif msg.content.type == "audio": - content = [{"type": "audio_url", "audio_url": {"url": str(msg.content.audio_url)}}] - - formatted_messages.append({"role": msg.role, "content": content}) - return formatted_messages +''' +{'messages': [{'role': 'user', 'content': 'Hello. Hello. 
Hello.'}, {'role': 'user', 'content': 'Unprocessedable.'}], 'tools': [], 'tool_choice': 'none', 'model': 'gpt-3.5-turbo', 'stream': True} ''' security = HTTPBearer() def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)): token = credentials.credentials if token != os.getenv("API_TOKEN"): + logger.warning("Invalid or missing token") raise HTTPException(status_code=403, detail="Invalid or missing token") @app.post("/chat/completions", dependencies=[Depends(verify_token)]) async def create_chat_completion(request: ChatCompletionRequest, req: Request): try: - messages = format_openai_messages(request.messages) + logger.debug(f"Received request: {request.json()}") client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = await client.chat.completions.create( model=request.model, - messages=messages, - temperature=request.temperature, - stream=request.stream + messages=request.messages, # Directly use request messages + tool_choice=request.tool_choice if request.tools and request.tool_choice else None, + tools=request.tools if request.tools else None, + # modalities=request.modalities, + response_format=request.response_format, + stream=request.stream, + stream_options=request.stream_options ) - async def generate(): - try: - async for chunk in response: - if chunk.choices[0].delta.content is not None: - yield f"data: {chunk.choices[0].delta.content}\n\n" - yield "data: [DONE]\n\n" - except asyncio.CancelledError: - print("Request was cancelled") - raise - - return StreamingResponse(generate(), media_type="text/event-stream") + if request.stream: + async def generate(): + try: + async for chunk in response: + logger.info(f"Received chunk: {chunk}") + yield f"data: {json.dumps(chunk.to_dict())}\n\n" + yield "data: [DONE]\n\n" + except asyncio.CancelledError: + logger.info("Request was cancelled") + raise + + return StreamingResponse(generate(), media_type="text/event-stream") + else: + # Non-streaming: the completion was already awaited above + return response except asyncio.CancelledError: - print("Request was cancelled") + logger.info("Request was cancelled") raise HTTPException(status_code=499, detail="Request was cancelled") except Exception as e: traceback_str = ''.join(traceback.format_tb(e.__traceback__)) error_message = f"{str(e)}\n{traceback_str}" - print(error_message) + logger.error(error_message) raise HTTPException(status_code=500, detail=error_message) if __name__ == "__main__": diff --git a/agents/ten_packages/extension/glue_python_async/extension.py b/agents/ten_packages/extension/glue_python_async/extension.py index 2fa8cc3d..ceb9d235 100644 --- a/agents/ten_packages/extension/glue_python_async/extension.py +++ b/agents/ten_packages/extension/glue_python_async/extension.py @@ -6,14 +6,16 @@ import asyncio import traceback import aiohttp +import json from datetime import datetime -from typing import List +from typing import List, Any, AsyncGenerator +from dataclasses import dataclass +from pydantic import BaseModel from ten import ( AudioFrame, VideoFrame, - AsyncExtension, AsyncTenEnv, Cmd, StatusCode, @@ -21,10 +23,15 @@ Data, ) -PROPERTY_API_URL = "api_url" -PROPERTY_USER_ID = "user_id" -PROPERTY_PROMPT = "prompt" -PROPERTY_TOKEN = "token" +from ten_ai_base import BaseConfig, ChatMemory +from ten_ai_base.llm import AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata +from ten_ai_base.types import LLMChatCompletionUserMessageParam, LLMToolResult + +CMD_IN_FLUSH = "flush" +CMD_IN_ON_USER_JOINED = "on_user_joined" +CMD_IN_ON_USER_LEFT = 
"on_user_left" +CMD_OUT_FLUSH = "flush" +CMD_OUT_TOOL_CALL = "tool_call" DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final" DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text" @@ -32,6 +39,8 @@ DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text" DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment" +CMD_PROPERTY_RESULT = "tool_result" + def is_punctuation(char): if char in [",", ",", ".", "。", "?", "?", "!", "!"]: return True @@ -51,87 +60,249 @@ def parse_sentences(sentence_fragment, content): remain = current_sentence return sentences, remain -class AsyncGlueExtension(AsyncExtension): +class ToolCallFunction(BaseModel): + name: str | None = None + arguments: str | None = None + +class ToolCall(BaseModel): + index: int + type: str = "function" + id: str | None = None + function: ToolCallFunction + +class ToolCallResponse(BaseModel): + id: str + response: LLMToolResult + error: str | None = None + +class Delta(BaseModel): + content: str | None = None + tool_calls: List[ToolCall] = None + +class Choice(BaseModel): + delta: Delta = None + index: int + finish_reason: str | None + +class CompletionTokensDetails(BaseModel): + accepted_prediction_tokens: int = 0 + audio_tokens: int = 0 + reasoning_tokens: int = 0 + rejected_prediction_tokens: int = 0 + +class PromptTokensDetails(BaseModel): + audio_tokens: int = 0 + cached_tokens: int = 0 + +class Usage(BaseModel): + completion_tokens: int = 0 + prompt_tokens: int = 0 + total_tokens: int = 0 + + completion_tokens_details: CompletionTokensDetails | None = None + prompt_tokens_details: PromptTokensDetails | None = None + +class ResponseChunk(BaseModel): + choices: List[Choice] + usage: Usage | None = None + +@dataclass +class GlueConfig(BaseConfig): api_url: str = "http://localhost:8000/chat/completions" - user_id: str = "TenAgent" - prompt: str = "" token: str = "" - outdate_ts = datetime.now() + prompt: str = "" + max_history: int = 10 + greeting: str = "" + +class AsyncGlueExtension(AsyncLLMBaseExtension): + config : GlueConfig = None sentence_fragment: str = "" ten_env: AsyncTenEnv = None loop: asyncio.AbstractEventLoop = None stopped: bool = False - queue = asyncio.Queue() - history: List[dict] = [] - max_history: int = 10 - session: aiohttp.ClientSession = None + memory: ChatMemory = None + total_usage: Usage = Usage() + users_count = 0 async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) ten_env.log_debug("on_init") - ten_env.on_init_done() async def on_start(self, ten_env: AsyncTenEnv) -> None: + await super().on_start(ten_env) ten_env.log_debug("on_start") self.loop = asyncio.get_event_loop() - try: - self.api_url = ten_env.get_property_string(PROPERTY_API_URL) - except Exception as err: - ten_env.log_error(f"GetProperty optional {PROPERTY_API_URL} failed, err: {err}") - return + self.config = GlueConfig.create(ten_env=ten_env) + ten_env.log_info(f"config: {self.config}") - try: - self.user_id = ten_env.get_property_string(PROPERTY_USER_ID) - except Exception as err: - ten_env.log_error(f"GetProperty optional {PROPERTY_USER_ID} failed, err: {err}") - - try: - self.prompt = ten_env.get_property_string(PROPERTY_PROMPT) - except Exception as err: - ten_env.log_error(f"GetProperty optional {PROPERTY_PROMPT} failed, err: {err}") - - try: - self.token = ten_env.get_property_string(PROPERTY_TOKEN) - except Exception as err: - ten_env.log_error(f"GetProperty optional {PROPERTY_TOKEN} failed, err: {err}") - - try: - self.max_history = ten_env.get_property_int("max_memory_length") - except Exception as err: - 
ten_env.log_error(f"GetProperty optional max_memory_length failed, err: {err}") + self.memory = ChatMemory(self.config.max_history) self.ten_env = ten_env - self.loop.create_task(self._consume()) async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) ten_env.log_debug("on_stop") self.stopped = True await self.queue.put(None) - await self._flush() async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) ten_env.log_debug("on_deinit") async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: cmd_name = cmd.get_name() ten_env.log_debug("on_cmd name {}".format(cmd_name)) - if cmd_name == "flush": - try: - await self._flush() - await ten_env.send_cmd(Cmd.create("flush")) - ten_env.log_info("on flush") - except Exception as e: - ten_env.log_error(f"{traceback.format_exc()} \n Failed to handle {e}") + status = StatusCode.OK + detail = "success" + + if cmd_name == CMD_IN_FLUSH: + await self.flush_input_items(ten_env) + await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH)) + ten_env.log_info("on flush") + elif cmd_name == CMD_IN_ON_USER_JOINED: + self.users_count += 1 + # Send greeting when first user joined + if self.config.greeting and self.users_count == 1: + self.send_text_output(ten_env, self.config.greeting, True) + elif cmd_name == CMD_IN_ON_USER_LEFT: + self.users_count -= 1 + else: + await super().on_cmd(ten_env, cmd) + return - cmd_result = CmdResult.create(StatusCode.OK) + cmd_result = CmdResult.create(status) + cmd_result.set_property_string("detail", detail) ten_env.return_result(cmd_result, cmd) + async def on_call_chat_completion(self, ten_env: AsyncTenEnv, **kargs: LLMCallCompletionArgs) -> any: + raise Exception("Not implemented") + + async def on_data_chat_completion(self, ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs) -> None: + input: LLMChatCompletionUserMessageParam = kargs.get("messages", []) + + messages = [] + if self.config.prompt: + messages.append({"role": "system", "content": self.config.prompt}) + messages.extend(self.memory.get()) + if not input: + ten_env.log_warn("No message in data") + else: + messages.extend(input) + for i in input: + self.memory.put(i) + + def tool_dict(tool: LLMToolMetadata): + json = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False + }, + }, + "strict": True + } + + for param in tool.parameters: + json["function"]["parameters"]["properties"][param.name] = { + "type": param.type, + "description": param.description + } + if param.required: + json["function"]["parameters"]["required"].append(param.name) + + return json + tools = [] + for tool in self.available_tools: + tools.append(tool_dict(tool)) + + total_output = "" + sentence_fragment = "" + calls = {} + + sentences = [] + response = self._stream_chat(messages=messages, tools=tools) + async for message in response: + self.ten_env.log_info(f"content: {message}") + # TODO: handle tool call + try: + c = ResponseChunk(**message) + if c.choices: + if c.choices[0].delta.content: + total_output += c.choices[0].delta.content + sentences, sentence_fragment = parse_sentences(sentence_fragment, c.choices[0].delta.content) + for s in sentences: + await self._send_text(s) + if c.choices[0].delta.tool_calls: + self.ten_env.log_info(f"tool_calls: {c.choices[0].delta.tool_calls}") + for call in c.choices[0].delta.tool_calls: + if call.index not in calls: + calls[call.index] = 
ToolCall(id=call.id, index=call.index, function=ToolCallFunction(name="", arguments="")) + if call.function.name: + calls[call.index].function.name += call.function.name + if call.function.arguments: + calls[call.index].function.arguments += call.function.arguments + if c.usage: + self.ten_env.log_info(f"usage: {c.usage}") + self._update_usage(c.usage) + except Exception as e: + self.ten_env.log_error(f"Failed to parse response: {message} {e}") + traceback.print_exc() + + if total_output: + self.memory.put({"role": "assistant", "content": total_output}) + + if calls: + tasks = [] + tool_calls = [] + for _, call in calls.items(): + self.ten_env.log_info(f"tool call: {call}") + tool_calls.append(call.model_dump()) + tasks.append(self.handle_tool_call(call)) + self.memory.put({"role": "assistant", "tool_calls": tool_calls}) + responses = await asyncio.gather(*tasks) + for r in responses: + content = r.response["content"] + self.ten_env.log_info(f"tool call response: {content} {r.id}") + self.memory.put({"role": "tool", "content": json.dumps(content), "tool_call_id": r.id}) + + # request again to let the model know the tool call results + await self.on_data_chat_completion(ten_env) + + self.ten_env.log_info(f"total_output: {total_output} {calls}") + + async def on_tools_update(self, ten_env: AsyncTenEnv, tool: LLMToolMetadata) -> None: + # Implement the logic for tool updates + return await super().on_tools_update(ten_env, tool) + + async def handle_tool_call(self, call: ToolCall) -> ToolCallResponse: + cmd: Cmd = Cmd.create(CMD_OUT_TOOL_CALL) + cmd.set_property_string("name", call.function.name) + cmd.set_property_from_json("arguments", call.function.arguments) + + # Send the command and handle the result through the future + result: CmdResult = await self.ten_env.send_cmd(cmd) + if result.get_status_code() == StatusCode.OK: + tool_result: LLMToolResult = json.loads( + result.get_property_to_json(CMD_PROPERTY_RESULT)) + + self.ten_env.log_info(f"tool_result: {call} {tool_result}") + return ToolCallResponse(id=call.id, response=tool_result) + else: + self.ten_env.log_error(f"Tool call failed") + return ToolCallResponse(id=call.id, error=f"Tool call failed with status code {result.get_status_code()}") + async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: data_name = data.get_name() - ten_env.log_debug("on_data name {}".format(data_name)) + ten_env.log_info("on_data name {}".format(data_name)) is_final = False input_text = "" @@ -154,8 +325,10 @@ async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: ten_env.log_info(f"OnData input text: [{input_text}]") - ts = datetime.now() - await self.queue.put((input_text, ts)) + # Start an asynchronous task for handling chat completion + message = LLMChatCompletionUserMessageParam( + role="user", content=input_text) + await self.queue_input_item(False, messages=[message]) async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None: pass @@ -163,75 +336,36 @@ async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> async def on_video_frame(self, ten_env: AsyncTenEnv, video_frame: VideoFrame) -> None: pass - async def _flush(self): - # self.ten_env.log_info("flush") - self.outdate_ts = datetime.now() - if self.session: - await self.session.close() - self.session = None - - def _need_interrrupt(self, ts: datetime) -> bool: - return self.outdate_ts > ts - async def _send_text(self, text: str) -> None: data = Data.create("text_data") 
data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text) data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, True) self.ten_env.send_data(data) - async def _consume(self) -> None: - self.ten_env.log_info("start async loop") - while not self.stopped: - try: - value = await self.queue.get() - if value is None: - self.ten_env.log_info("async loop exit") - break - input, ts = value - if self._need_interrrupt(ts): - continue - - await self._chat(input, ts) - except Exception as e: - self.ten_env.log_error(f"Failed to handle {e}") - - async def _add_to_history(self, role: str, content: str) -> None: - self.history.append({"role": role, "content": content}) - if len(self.history) > self.max_history: - self.history = self.history[1:] - - async def _get_messages(self) -> List[dict]: - messages = [] - if self.prompt: - messages.append({"role": "system", "content": self.prompt}) - messages.extend(self.history) - return messages - - async def _chat(self, input: str, ts: datetime) -> None: - self.session = aiohttp.ClientSession() + async def _stream_chat(self, messages: List[Any], tools: List[Any]) -> AsyncGenerator[dict, None]: + session = aiohttp.ClientSession() try: - messages = await self._get_messages() - messages.append({"role": "user", "content": input}) - await self._add_to_history("user", input) payload = { - "messages": [{"role": msg["role"], "content": {"type": "text", "text": msg["content"]}} for msg in messages], + "messages": messages, + "tools": tools, + "tool_choice": "auto" if tools else "none", "model": "gpt-3.5-turbo", - "temperature": 1.0, - "stream": True + "stream": True, + "stream_options": {"include_usage": True} } - self.ten_env.log_info(f"payload before sending: {payload}") + self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}") headers = { - "Authorization": f"Bearer {self.token}", + "Authorization": f"Bearer {self.config.token}", "Content-Type": "application/json" } - total_output = "" - async with self.session.post(self.api_url, json=payload, headers=headers) as response: + + async with session.post(self.config.api_url, json=payload, headers=headers) as response: + if response.status != 200: + r = await response.json() + self.ten_env.log_error(f"Received unexpected status {response.status} from the server: {r}") + return + async for line in response.content: - if self._need_interrrupt(ts): - self.ten_env.log_info("interrupted") - total_output += "[interrupted]" - break - if line: l = line.decode('utf-8').strip() if l.startswith("data:"): @@ -239,15 +373,26 @@ async def _chat(self, input: str, ts: datetime) -> None: if content == "[DONE]": break self.ten_env.log_info(f"content: {content}") - sentences, self.sentence_fragment = parse_sentences(self.sentence_fragment, content) - for s in sentences: - await self._send_text(s) - total_output += s - self.ten_env.log_info(f"total_output: {total_output}") - await self._add_to_history("assistant", total_output) + yield json.loads(content) except Exception as e: - traceback.print_exc() self.ten_env.log_error(f"Failed to handle {e}") finally: - await self.session.close() - self.session = None + await session.close() + session = None + + async def _update_usage(self, usage: Usage) -> None: + self.total_usage.completion_tokens += usage.completion_tokens + self.total_usage.prompt_tokens += usage.prompt_tokens + self.total_usage.total_tokens += usage.total_tokens + + if usage.completion_tokens_details: + if self.total_usage.completion_tokens_details is None: + self.total_usage.completion_tokens_details = CompletionTokensDetails() + self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage.completion_tokens_details.accepted_prediction_tokens + self.total_usage.completion_tokens_details.audio_tokens += usage.completion_tokens_details.audio_tokens + self.total_usage.completion_tokens_details.reasoning_tokens += usage.completion_tokens_details.reasoning_tokens + self.total_usage.completion_tokens_details.rejected_prediction_tokens += usage.completion_tokens_details.rejected_prediction_tokens + + if usage.prompt_tokens_details: + if self.total_usage.prompt_tokens_details is None: + self.total_usage.prompt_tokens_details = PromptTokensDetails() + self.total_usage.prompt_tokens_details.audio_tokens += usage.prompt_tokens_details.audio_tokens + self.total_usage.prompt_tokens_details.cached_tokens += usage.prompt_tokens_details.cached_tokens + + self.ten_env.log_info(f"total usage: {self.total_usage}") \ No newline at end of file diff --git a/agents/ten_packages/extension/glue_python_async/log.py b/agents/ten_packages/extension/glue_python_async/log.py deleted file mode 100644 index 5e7ec888..00000000 --- a/agents/ten_packages/extension/glue_python_async/log.py +++ /dev/null @@ -1,20 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file for more information. -# -import logging - -logger = logging.getLogger("glue_python_async") -logger.setLevel(logging.INFO) - -formatter_str = ( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - " - "[%(filename)s:%(lineno)d] - %(message)s" -) -formatter = logging.Formatter(formatter_str) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/glue_python_async/schema.yml b/agents/ten_packages/extension/glue_python_async/schema.yml index ce885c7b..099fdff5 100644 --- a/agents/ten_packages/extension/glue_python_async/schema.yml +++ b/agents/ten_packages/extension/glue_python_async/schema.yml @@ -15,14 +15,16 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/ChatCompletionRequest' + $ref: "#/components/schemas/ChatCompletionRequest" responses: - '200': + "200": description: Successful response content: application/json: schema: - $ref: '#/components/schemas/ChatCompletionResponse' + oneOf: + - $ref: "#/components/schemas/ChatCompletionResponse" + - $ref: "#/components/schemas/ChatCompletionChunk" x-stream: true components: @@ -31,45 +33,150 @@ components: type: object required: - messages - - model properties: + context: + type: object + model: + type: string + example: "gpt-4o" messages: type: array items: - $ref: '#/components/schemas/Message' - model: - type: string - example: "gpt-4-vision-preview" - temperature: - type: number - format: float - minimum: 0 - maximum: 2 - default: 1 + oneOf: + - $ref: "#/components/schemas/SystemMessage" + - $ref: "#/components/schemas/UserMessage" + - $ref: "#/components/schemas/AssistantMessage" + - $ref: "#/components/schemas/ToolMessage" + response_format: + $ref: "#/components/schemas/ResponseFormat" + modalities: + type: array + items: + type: string + enum: [text, audio] + default: [text] + audio: + type: object + properties: + voice: + type: string + tools: + type: array + items: + $ref: "#/components/schemas/Tool" + tool_choice: + oneOf: + - type: string + enum: [auto, none, required] + - type: object + properties: + type: + type: string + enum: [function] + function: + type: object + parallel_tool_calls: + type: boolean + default: true stream: type: boolean default: true - Message: + SystemMessage: type: object required: - role - content properties: + name: + type: string role: 
type: string - enum: [system, user, assistant] + enum: [system] content: oneOf: - - $ref: '#/components/schemas/TextContent' - - $ref: '#/components/schemas/ImageContent' - - $ref: '#/components/schemas/AudioContent' + - type: string + - type: array + items: + type: string + + UserMessage: + type: object + required: + - role + - content + properties: + name: + type: string + role: + type: string + enum: [user] + content: + oneOf: + - type: string - type: array items: oneOf: - - $ref: "#/components/schemas/TextContent" - - $ref: "#/components/schemas/ImageContent" - - $ref: "#/components/schemas/AudioContent" + - $ref: "#/components/schemas/TextContent" + - $ref: "#/components/schemas/ImageContent" + - $ref: "#/components/schemas/AudioContent" + + AssistantMessage: + type: object + required: + - role + - content + properties: + name: + type: string + role: + type: string + enum: [assistant] + audio: + type: object + properties: + id: + type: string + content: + oneOf: + - type: string + - type: array + items: + $ref: "#/components/schemas/TextContent" + tool_calls: + type: array + items: + type: object + properties: + id: + type: string + type: + type: string + enum: [function] + function: + type: object + properties: + name: + type: string + arguments: + type: string + + ToolMessage: + type: object + required: + - role + - content + - tool_call_id + properties: + role: + type: string + enum: [tool] + content: + oneOf: + - type: string + - type: array + items: + type: string + tool_call_id: + type: string TextContent: type: object @@ -91,7 +198,7 @@ properties: type: type: string - enum: [image] + enum: [image_url] image_url: type: string format: uri @@ -100,14 +207,53 @@ type: object required: - type - - audio_url + - input_audio properties: type: type: string - enum: [audio] - audio_url: + enum: [input_audio] + input_audio: + type: object + properties: + data: + type: string + format: + type: string + + Tool: + type: object + properties: + type: type: string + enum: [function] + function: + type: object + required: + - name + properties: + name: + type: string + description: + type: string + parameters: + type: object + strict: + type: boolean + default: false + + ResponseFormat: + type: object + properties: + type: + type: string + enum: [json_schema] + json_schema: + type: object + properties: + name: + type: string + schema: + type: object ChatCompletionResponse: type: object @@ -120,24 +266,123 @@ type: integer model: type: string + usage: + $ref: "#/components/schemas/Usage" choices: type: array items: - $ref: '#/components/schemas/Choice' + $ref: "#/components/schemas/Choice" + + ChatCompletionChunk: + type: object + properties: + id: + type: string + object: + type: string + created: + type: integer + model: + type: string + usage: + $ref: "#/components/schemas/Usage" + choices: + type: array + items: + $ref: "#/components/schemas/DeltaChoice" + + Usage: + type: object + properties: + completion_tokens: + type: integer + prompt_tokens: + type: integer + total_tokens: + type: integer + completion_tokens_details: + type: object + properties: + accepted_prediction_tokens: + type: integer + audio_tokens: + type: integer + reasoning_tokens: + type: integer + rejected_prediction_tokens: + type: integer + prompt_tokens_details: + type: object + properties: + audio_tokens: + type: integer + cached_tokens: + type: integer Choice: type: object properties: message: $ref: "#/components/schemas/ResponseMessage" index: type: integer finish_reason: type: 
string + + DeltaChoice: type: object properties: delta: - $ref: '#/components/schemas/Delta' + $ref: "#/components/schemas/ResponseMessage" index: type: integer finish_reason: type: string - nullable: true Delta: type: object properties: content: - type: string \ No newline at end of file + type: string + + ResponseMessage: + type: object + properties: + content: + type: string + refusal: + type: string + tool_calls: + $ref: "#/components/schemas/ToolCall" + role: + type: string + audio: + $ref: "#/components/schemas/Audio" + + ToolCall: + type: object + properties: + id: + type: string + type: + type: string + enum: [function] + function: + type: object + properties: + name: + type: string + arguments: + type: string + + Audio: + type: object + properties: + id: + type: string + expires_at: + type: integer + data: + type: string + transcript: + type: string diff --git a/agents/ten_packages/extension/interrupt_detector_python/__init__.py b/agents/ten_packages/extension/interrupt_detector_python/__init__.py index 8692cc02..bdddf6e7 100644 --- a/agents/ten_packages/extension/interrupt_detector_python/__init__.py +++ b/agents/ten_packages/extension/interrupt_detector_python/__init__.py @@ -1,4 +1 @@ -from . import interrupt_detector_addon -from .log import logger - -logger.info("interrupt_detector_python extension loaded") +from . import addon \ No newline at end of file diff --git a/agents/ten_packages/extension/interrupt_detector_python/interrupt_detector_addon.py b/agents/ten_packages/extension/interrupt_detector_python/addon.py similarity index 75% rename from agents/ten_packages/extension/interrupt_detector_python/interrupt_detector_addon.py rename to agents/ten_packages/extension/interrupt_detector_python/addon.py index 0de07b5a..4090b4f3 100644 --- a/agents/ten_packages/extension/interrupt_detector_python/interrupt_detector_addon.py +++ b/agents/ten_packages/extension/interrupt_detector_python/addon.py @@ -12,13 +12,11 @@ TenEnv, ) - @register_addon_as_extension("interrupt_detector_python") class InterruptDetectorExtensionAddon(Addon): def on_create_instance(self, ten: TenEnv, addon_name: str, context) -> None: - from .log import logger - logger.info("on_create_instance") + ten.log_info("on_create_instance") - from .interrupt_detector_extension import InterruptDetectorExtension + from .extension import InterruptDetectorExtension ten.on_create_instance_done(InterruptDetectorExtension(addon_name), context) diff --git a/agents/ten_packages/extension/interrupt_detector_python/interrupt_detector_extension.py b/agents/ten_packages/extension/interrupt_detector_python/extension.py similarity index 83% rename from agents/ten_packages/extension/interrupt_detector_python/interrupt_detector_extension.py rename to agents/ten_packages/extension/interrupt_detector_python/extension.py index 6710d8dc..7ddb310e 100644 --- a/agents/ten_packages/extension/interrupt_detector_python/interrupt_detector_extension.py +++ b/agents/ten_packages/extension/interrupt_detector_python/extension.py @@ -14,36 +14,33 @@ StatusCode, CmdResult, ) -from .log import logger - CMD_NAME_FLUSH = "flush" TEXT_DATA_TEXT_FIELD = "text" TEXT_DATA_FINAL_FIELD = "is_final" - class InterruptDetectorExtension(Extension): def on_start(self, ten: TenEnv) -> None: - logger.info("on_start") + ten.log_info("on_start") ten.on_start_done() def on_stop(self, ten: TenEnv) -> None: - logger.info("on_stop") + ten.log_info("on_stop") ten.on_stop_done() def send_flush_cmd(self, ten: TenEnv) -> None: flush_cmd = Cmd.create(CMD_NAME_FLUSH) 
ten.send_cmd( flush_cmd, - lambda ten, result: logger.info("send_cmd done"), + lambda ten, result: ten.log_info("send_cmd done"), ) - logger.info(f"sent cmd: {CMD_NAME_FLUSH}") + ten.log_info(f"sent cmd: {CMD_NAME_FLUSH}") def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: cmd_name = cmd.get_name() - logger.info("on_cmd name {}".format(cmd_name)) + ten.log_info("on_cmd name {}".format(cmd_name)) # flush whatever cmd incoming at the moment self.send_flush_cmd(ten) @@ -53,7 +50,7 @@ def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: new_cmd = Cmd.create_from_json(cmd_json) ten.send_cmd( new_cmd, - lambda ten, result: logger.info("send_cmd done"), + lambda ten, result: ten.log_info("send_cmd done"), ) cmd_result = CmdResult.create(StatusCode.OK) @@ -67,12 +64,12 @@ def on_data(self, ten: TenEnv, data: Data) -> None: example: {name: text_data, properties: {text: "hello", is_final: false} """ - logger.info(f"on_data") + ten.log_info(f"on_data") try: text = data.get_property_string(TEXT_DATA_TEXT_FIELD) except Exception as e: - logger.warning( + ten.log_warn( f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}" ) return @@ -80,12 +77,12 @@ def on_data(self, ten: TenEnv, data: Data) -> None: try: final = data.get_property_bool(TEXT_DATA_FINAL_FIELD) except Exception as e: - logger.warning( + ten.log_warn( f"on_data get_property_bool {TEXT_DATA_FINAL_FIELD} error: {e}" ) return - logger.debug( + ten.log_debug( f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final}" ) diff --git a/agents/ten_packages/extension/interrupt_detector_python/log.py b/agents/ten_packages/extension/interrupt_detector_python/log.py deleted file mode 100644 index 303d06e1..00000000 --- a/agents/ten_packages/extension/interrupt_detector_python/log.py +++ /dev/null @@ -1,13 +0,0 @@ -import logging - -logger = logging.getLogger("interrupt_detector_python") -logger.setLevel(logging.INFO) - -formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s" -) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/message_collector_rtm/__init__.py b/agents/ten_packages/extension/message_collector_rtm/__init__.py index 3d42c9b5..645dc801 100644 --- a/agents/ten_packages/extension/message_collector_rtm/__init__.py +++ b/agents/ten_packages/extension/message_collector_rtm/__init__.py @@ -6,6 +6,3 @@ # # from .src import addon -from .src.log import logger - -logger.info("message_collector_rtm extension loaded") diff --git a/agents/ten_packages/extension/message_collector_rtm/manifest.json b/agents/ten_packages/extension/message_collector_rtm/manifest.json index 506f2e01..f5b2a335 100644 --- a/agents/ten_packages/extension/message_collector_rtm/manifest.json +++ b/agents/ten_packages/extension/message_collector_rtm/manifest.json @@ -48,18 +48,6 @@ "type": "string" } } - }, - { - "name": "rtm_storage_event", - "property": {} - }, - { - "name": "rtm_presence_event", - "property": {} - }, - { - "name": "rtm_lock_event", - "property": {} } ], "data_out": [ @@ -75,12 +63,6 @@ } } ], - "cmd_in": [ - { - "name": "on_user_audio_track_state_changed", - "property": {} - } - ], "cmd_out": [ { "name": "publish", diff --git a/agents/ten_packages/extension/message_collector_rtm/src/addon.py b/agents/ten_packages/extension/message_collector_rtm/src/addon.py index bd53f206..1602995b 100644 --- 
a/agents/ten_packages/extension/message_collector_rtm/src/addon.py +++ b/agents/ten_packages/extension/message_collector_rtm/src/addon.py @@ -17,6 +17,5 @@ class MessageCollectorRTMExtension(Addon): def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: from .extension import MessageCollectorRTMExtension - from .log import logger - logger.info("MessageCollectorRTMExtensionAddon on_create_instance") + ten_env.log_info("MessageCollectorRTMExtensionAddon on_create_instance") ten_env.on_create_instance_done(MessageCollectorRTMExtension(name), context) diff --git a/agents/ten_packages/extension/message_collector_rtm/src/extension.py b/agents/ten_packages/extension/message_collector_rtm/src/extension.py index 04e3a809..5b821ed7 100644 --- a/agents/ten_packages/extension/message_collector_rtm/src/extension.py +++ b/agents/ten_packages/extension/message_collector_rtm/src/extension.py @@ -5,11 +5,11 @@ # Copyright (c) 2024 Agora IO. All rights reserved. # # -import base64 import json -import threading import time import uuid +import asyncio + from ten import ( AudioFrame, VideoFrame, @@ -20,8 +20,6 @@ CmdResult, Data, ) -import asyncio - TEXT_DATA_TEXT_FIELD = "text" TEXT_DATA_FINAL_FIELD = "is_final" diff --git a/agents/ten_packages/extension/message_collector_rtm/src/log.py b/agents/ten_packages/extension/message_collector_rtm/src/log.py deleted file mode 100644 index 1cbb30f5..00000000 --- a/agents/ten_packages/extension/message_collector_rtm/src/log.py +++ /dev/null @@ -1,22 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by Wei Hu in 2024-08. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import logging - -logger = logging.getLogger("message_collector_rtm") -logger.setLevel(logging.INFO) - -formatter_str = ( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - " - "[%(filename)s:%(lineno)d] - %(message)s" -) -formatter = logging.Formatter(formatter_str) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/weatherapi_tool_python/__init__.py b/agents/ten_packages/extension/weatherapi_tool_python/__init__.py index f9c90d97..27de41c9 100644 --- a/agents/ten_packages/extension/weatherapi_tool_python/__init__.py +++ b/agents/ten_packages/extension/weatherapi_tool_python/__init__.py @@ -6,6 +6,3 @@ # # from . 
import addon -from .log import logger - -logger.info("weatherapi_tool_python extension loaded") diff --git a/agents/ten_packages/extension/weatherapi_tool_python/addon.py b/agents/ten_packages/extension/weatherapi_tool_python/addon.py index ba56a0c6..e34608d7 100644 --- a/agents/ten_packages/extension/weatherapi_tool_python/addon.py +++ b/agents/ten_packages/extension/weatherapi_tool_python/addon.py @@ -17,6 +17,5 @@ class WeatherToolExtensionAddon(Addon): def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: from .extension import WeatherToolExtension - from .log import logger - logger.info("WeatherToolExtensionAddon on_create_instance") + ten_env.log_info("WeatherToolExtensionAddon on_create_instance") ten_env.on_create_instance_done(WeatherToolExtension(name), context) diff --git a/agents/ten_packages/extension/weatherapi_tool_python/extension.py b/agents/ten_packages/extension/weatherapi_tool_python/extension.py index f9d38543..8614914a 100644 --- a/agents/ten_packages/extension/weatherapi_tool_python/extension.py +++ b/agents/ten_packages/extension/weatherapi_tool_python/extension.py @@ -8,25 +8,17 @@ import json import aiohttp -import requests from typing import Any +from dataclasses import dataclass + +from ten import Cmd -from ten import ( - AudioFrame, - VideoFrame, - Extension, - TenEnv, - Cmd, - StatusCode, - CmdResult, - Data, -) from ten.async_ten_env import AsyncTenEnv from ten_ai_base.helper import get_properties_string +from ten_ai_base import BaseConfig from ten_ai_base.llm_tool import AsyncLLMToolBaseExtension -from ten_ai_base.types import LLMChatCompletionToolMessageParam, LLMChatCompletionUserMessageParam, LLMToolMetadata, LLMToolMetadataParameter, LLMToolResult -from .log import logger +from ten_ai_base.types import LLMToolMetadata, LLMToolMetadataParameter, LLMToolResult CMD_TOOL_REGISTER = "tool_register" CMD_TOOL_CALL = "tool_call" @@ -85,12 +77,16 @@ PROPERTY_API_KEY = "api_key" # Required +@dataclass +class WeatherToolConfig(BaseConfig): + api_key: str = "" class WeatherToolExtension(AsyncLLMToolBaseExtension): def __init__(self, name: str) -> None: super().__init__(name) - self.api_key = None self.session = None + self.ten_env = None + self.config : WeatherToolConfig = None async def on_init(self, ten_env: AsyncTenEnv) -> None: ten_env.log_debug("on_init") @@ -98,13 +94,13 @@ async def on_init(self, ten_env: AsyncTenEnv) -> None: async def on_start(self, ten_env: AsyncTenEnv) -> None: ten_env.log_debug("on_start") - await super().on_start(ten_env) - get_properties_string( - ten_env, [PROPERTY_API_KEY], lambda name, value: setattr(self, name, value)) - if not self.api_key: - ten_env.log_info(f"API key is missing, exiting on_start") - return + self.config = WeatherToolConfig.create(ten_env=ten_env) + ten_env.log_info(f"config: {self.config}") + if self.config.api_key: + await super().on_start(ten_env) + else: + ten_env.log_error("api_key is missing, weather tools will not be registered") + + self.ten_env = ten_env async def on_stop(self, ten_env: AsyncTenEnv) -> None: ten_env.log_debug("on_stop") @@ -188,7 +184,7 @@ async def _get_current_weather(self, args: dict) -> Any: raise Exception("Failed to get property") location = args["location"] - url = f"http://api.weatherapi.com/v1/current.json?key={self.api_key}&q={location}&aqi=no" + url = f"http://api.weatherapi.com/v1/current.json?key={self.config.api_key}&q={location}&aqi=no" async with self.session.get(url) as response: result = await response.json() @@ -205,7 +201,7 @@ async def _get_past_weather(self, args: dict) -> Any: location = args["location"] datetime = 
args["datetime"] - url = f"http://api.weatherapi.com/v1/history.json?key={self.api_key}&q={location}&dt={datetime}" + url = f"http://api.weatherapi.com/v1/history.json?key={self.config.api_key}&q={location}&dt={datetime}" async with self.session.get(url) as response: result = await response.json() @@ -221,13 +217,13 @@ async def _get_future_weather(self, args: dict) -> Any: raise Exception("Failed to get property") location = args["location"] - url = f"http://api.weatherapi.com/v1/forecast.json?key={self.api_key}&q={location}&days=3&aqi=no&alerts=no" + url = f"http://api.weatherapi.com/v1/forecast.json?key={self.config.api_key}&q={location}&days=3&aqi=no&alerts=no" async with self.session.get(url) as response: result = await response.json() # Log the result - logger.info(f"get result {result}") + self.ten_env.log_info(f"get result {result}") # Remove all hourly data for d in result.get("forecast", {}).get("forecastday", []): diff --git a/agents/ten_packages/extension/weatherapi_tool_python/log.py b/agents/ten_packages/extension/weatherapi_tool_python/log.py deleted file mode 100644 index b2bda5e2..00000000 --- a/agents/ten_packages/extension/weatherapi_tool_python/log.py +++ /dev/null @@ -1,22 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by Tomas Liu in 2024-08. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import logging - -logger = logging.getLogger("weatherapi_tool_python") -logger.setLevel(logging.INFO) - -formatter_str = ( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - " - "[%(filename)s:%(lineno)d] - %(message)s" -) -formatter = logging.Formatter(formatter_str) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) From 885e3d7d28271ada6a9292d3d22f44f13a6e0173 Mon Sep 17 00:00:00 2001 From: TomasLiu Date: Tue, 19 Nov 2024 14:08:20 +0800 Subject: [PATCH 2/2] revert manifest change --- .../deepgram_asr_python/manifest.json | 3 ++- .../message_collector_rtm/manifest.json | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/agents/ten_packages/extension/deepgram_asr_python/manifest.json b/agents/ten_packages/extension/deepgram_asr_python/manifest.json index 9e3298f6..e7914dd6 100644 --- a/agents/ten_packages/extension/deepgram_asr_python/manifest.json +++ b/agents/ten_packages/extension/deepgram_asr_python/manifest.json @@ -26,7 +26,8 @@ }, "audio_frame_in": [ { - "name": "pcm_frame" + "name": "pcm_frame", + "property": {} } ], "cmd_in": [ diff --git a/agents/ten_packages/extension/message_collector_rtm/manifest.json b/agents/ten_packages/extension/message_collector_rtm/manifest.json index f5b2a335..506f2e01 100644 --- a/agents/ten_packages/extension/message_collector_rtm/manifest.json +++ b/agents/ten_packages/extension/message_collector_rtm/manifest.json @@ -48,6 +48,18 @@ "type": "string" } } + }, + { + "name": "rtm_storage_event", + "property": {} + }, + { + "name": "rtm_presence_event", + "property": {} + }, + { + "name": "rtm_lock_event", + "property": {} } ], "data_out": [ @@ -63,6 +75,12 @@ } } ], + "cmd_in": [ + { + "name": "on_user_audio_track_state_changed", + "property": {} + } + ], "cmd_out": [ { "name": "publish",