From c19ec722e24991dd1ca3028845c280fb290cdbca Mon Sep 17 00:00:00 2001 From: TomasLiu Date: Mon, 25 Nov 2024 12:26:24 +0800 Subject: [PATCH 1/2] refact openai v2v --- agents/examples/demo/property.json | 4 +- agents/examples/experimental/property.json | 10 +- .../examples/openai_wrapper.py | 5 - .../extension/glue_python_async/extension.py | 208 +++-- .../extension/glue_python_async/manifest.json | 38 + .../glue_python_async/requirements.txt | 1 + .../extension/glue_python_async/schema.yml | 3 + .../extension/openai_v2v_python/__init__.py | 3 - .../extension/openai_v2v_python/addon.py | 9 +- .../extension/openai_v2v_python/conf.py | 50 -- .../extension/openai_v2v_python/extension.py | 799 +++++++----------- .../extension/openai_v2v_python/log.py | 22 - .../extension/openai_v2v_python/manifest.json | 38 +- .../openai_v2v_python/realtime/connection.py | 22 +- .../extension/openai_v2v_python/tools.py | 91 -- .../interface/ten_ai_base/__init__.py | 3 +- .../interface/ten_ai_base/chat_memory.py | 25 +- .../interface/ten_ai_base/config.py | 13 +- demo/src/app/api/agents/start/graph.tsx | 4 +- 19 files changed, 569 insertions(+), 779 deletions(-) create mode 100644 agents/ten_packages/extension/glue_python_async/requirements.txt delete mode 100644 agents/ten_packages/extension/openai_v2v_python/conf.py delete mode 100644 agents/ten_packages/extension/openai_v2v_python/log.py delete mode 100644 agents/ten_packages/extension/openai_v2v_python/tools.py diff --git a/agents/examples/demo/property.json b/agents/examples/demo/property.json index 6b357d11..2945bcfd 100644 --- a/agents/examples/demo/property.json +++ b/agents/examples/demo/property.json @@ -518,7 +518,7 @@ "language": "en-US", "server_vad": true, "dump": true, - "history": 10 + "max_history": 10 } }, { @@ -848,7 +848,7 @@ "language": "en-US", "server_vad": true, "dump": true, - "history": 10 + "max_history": 10 } }, { diff --git a/agents/examples/experimental/property.json b/agents/examples/experimental/property.json index b532fea5..aa9fecee 100644 --- a/agents/examples/experimental/property.json +++ b/agents/examples/experimental/property.json @@ -233,11 +233,11 @@ "language": "en-US", "server_vad": true, "dump": true, - "history": 10, + "max_history": 10, "vendor": "azure", "base_uri": "${env:AZURE_OPENAI_REALTIME_BASE_URI}", "path": "/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview", - "system_message": "" + "prompt": "" } }, { @@ -2444,7 +2444,7 @@ "language": "en-US", "server_vad": true, "dump": true, - "history": 10 + "max_history": 10 } }, { @@ -2566,7 +2566,7 @@ "language": "en-US", "server_vad": true, "dump": true, - "history": 10 + "max_history": 10 } }, { @@ -2724,7 +2724,7 @@ "language": "en-US", "server_vad": true, "dump": true, - "history": 10, + "max_history": 10, "enable_storage": true } }, diff --git a/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py b/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py index 8cde92f7..55e4053f 100644 --- a/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py +++ b/agents/ten_packages/extension/glue_python_async/examples/openai_wrapper.py @@ -14,11 +14,6 @@ from fastapi import Depends, FastAPI, HTTPException, Request import asyncio -# Enable Pydantic debug mode -from pydantic import BaseConfig - -BaseConfig.debug = True - # Set up logging logging.config.dictConfig({ "version": 1, diff --git a/agents/ten_packages/extension/glue_python_async/extension.py b/agents/ten_packages/extension/glue_python_async/extension.py index ceb9d235..02ecc4eb 100644 --- a/agents/ten_packages/extension/glue_python_async/extension.py +++ b/agents/ten_packages/extension/glue_python_async/extension.py @@ -7,10 +7,12 @@ import traceback import aiohttp import json +import time +import re -from datetime import datetime +import numpy as np from typing import List, Any, AsyncGenerator -from dataclasses import dataclass +from dataclasses import dataclass, field from pydantic import BaseModel from ten import ( @@ -23,7 +25,7 @@ Data, ) -from ten_ai_base import BaseConfig, ChatMemory +from ten_ai_base import BaseConfig, ChatMemory, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails, EVENT_MEMORY_APPENDED from ten_ai_base.llm import AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata from ten_ai_base.types import LLMChatCompletionUserMessageParam, LLMToolResult @@ -84,27 +86,9 @@ class Choice(BaseModel): index: int finish_reason: str | None -class CompletionTokensDetails(BaseModel): - accepted_prediction_tokens: int = 0 - audio_tokens: int = 0 - reasoning_tokens: int = 0 - rejected_prediction_tokens: int = 0 - -class PromptTokensDetails(BaseModel): - audio_tokens: int = 0 - cached_tokens: int = 0 - -class Usage(BaseModel): - completion_tokens: int = 0 - prompt_tokens: int = 0 - total_tokens: int = 0 - - completion_tokens_details: CompletionTokensDetails | None = None - prompt_tokens_details: PromptTokensDetails | None = None - class ResponseChunk(BaseModel): choices: List[Choice] - usage: Usage | None = None + usage: LLMUsage | None = None @dataclass class GlueConfig(BaseConfig): @@ -113,17 +97,29 @@ class GlueConfig(BaseConfig): prompt: str = "" max_history: int = 10 greeting: str = "" + failure_info: str = "" + modalities: List[str] = field(default_factory=lambda: ["text"]) + rtm_enabled: bool = True + ssml_enabled: bool = False + context_enabled: bool = False + extra_context: dict = field(default_factory=dict) + enable_storage: bool = False class AsyncGlueExtension(AsyncLLMBaseExtension): config : GlueConfig = None - sentence_fragment: str = "" ten_env: AsyncTenEnv = None loop: asyncio.AbstractEventLoop = None stopped: bool = False memory: ChatMemory = None - total_usage: Usage = Usage() + total_usage: LLMUsage = LLMUsage() users_count = 0 + completion_times = [] + connect_times = [] + first_token_times = [] + + remote_stream_id: int = 999 # TODO + async def on_init(self, ten_env: AsyncTenEnv) -> None: await super().on_init(ten_env) ten_env.log_debug("on_init") @@ -138,6 +134,7 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: ten_env.log_info(f"config: {self.config}") self.memory = ChatMemory(self.config.max_history) + self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended) self.ten_env = ten_env @@ -187,7 +184,21 @@ async def on_data_chat_completion(self, ten_env: AsyncTenEnv, **kargs: LLMDataCo messages = [] if self.config.prompt: messages.append({"role": "system", "content": self.config.prompt}) - messages.extend(self.memory.get()) + + history = self.memory.get() + while history: + if history[0].get("role") == "tool": + history = history[1:] + continue + if history[0].get("role") == "assistant" and history[0].get("tool_calls"): + history = history[1:] + continue + + # Skip the first tool role + break + + messages.extend(history) + if not input: ten_env.log_warn("No message in data") else: @@ -220,6 +231,10 @@ def tool_dict(tool: LLMToolMetadata): json["function"]["parameters"]["required"].append(param.name) return json + + def trim_xml(input_string): + return re.sub(r'<[^>]+>', '', input_string).strip() + tools = [] for tool in self.available_tools: tools.append(tool_dict(tool)) @@ -229,16 +244,25 @@ def tool_dict(tool: LLMToolMetadata): calls = {} sentences = [] + start_time = time.time() + first_token_time = None response = self._stream_chat(messages=messages, tools=tools) async for message in response: - self.ten_env.log_info(f"content: {message}") + self.ten_env.log_debug(f"content: {message}") # TODO: handle tool call try: c = ResponseChunk(**message) if c.choices: if c.choices[0].delta.content: - total_output += c.choices[0].delta.content - sentences, sentence_fragment = parse_sentences(sentence_fragment, c.choices[0].delta.content) + if first_token_time is None: + first_token_time = time.time() + self.first_token_times.append(first_token_time - start_time) + + content = c.choices[0].delta.content + if self.config.ssml_enabled and content.startswith(""): + content = trim_xml(content) + total_output += content + sentences, sentence_fragment = parse_sentences(sentence_fragment, content) for s in sentences: await self._send_text(s) if c.choices[0].delta.tool_calls: @@ -252,10 +276,14 @@ def tool_dict(tool: LLMToolMetadata): calls[call.index].function.arguments += call.function.arguments if c.usage: self.ten_env.log_info(f"usage: {c.usage}") - self._update_usage(c.usage) + await self._update_usage(c.usage) except Exception as e: self.ten_env.log_error(f"Failed to parse response: {message} {e}") traceback.print_exc() + if sentence_fragment: + await self._send_text(sentence_fragment) + end_time = time.time() + self.completion_times.append(end_time - start_time) if total_output: self.memory.put({"role": "assistant", "content": total_output}) @@ -343,48 +371,67 @@ async def _send_text(self, text: str) -> None: self.ten_env.send_data(data) async def _stream_chat(self, messages: List[Any], tools: List[Any]) -> AsyncGenerator[dict, None]: - session = aiohttp.ClientSession() - try: - payload = { - "messages": messages, - "tools": tools, - "tools_choice": "auto" if tools else "none", - "model": "gpt-3.5-turbo", - "stream": True, - "stream_options": {"include_usage": True} - } - self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}") - headers = { - "Authorization": f"Bearer {self.config.token}", - "Content-Type": "application/json" - } - - async with session.post(self.config.api_url, json=payload, headers=headers) as response: - if response.status != 200: - r = await response.json() - self.ten_env.log_error(f"Received unexpected status {r} from the server.") - return + async with aiohttp.ClientSession() as session: + try: + payload = { + "messages": messages, + "tools": tools, + "tools_choice": "auto" if tools else "none", + "model": "gpt-3.5-turbo", + "stream": True, + "stream_options": {"include_usage": True}, + "ssml_enabled": self.config.ssml_enabled + } + if self.config.context_enabled: + payload["context"] = { + **self.config.extra_context + } + self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}") + headers = { + "Authorization": f"Bearer {self.config.token}", + "Content-Type": "application/json" + } - async for line in response.content: - if line: - l = line.decode('utf-8').strip() - if l.startswith("data:"): - content = l[5:].strip() - if content == "[DONE]": - break - self.ten_env.log_info(f"content: {content}") - yield json.loads(content) - except Exception as e: - self.ten_env.log_error(f"Failed to handle {e}") - finally: - await session.close() - session = None + start_time = time.time() + async with session.post(self.config.api_url, json=payload, headers=headers) as response: + if response.status != 200: + r = await response.json() + self.ten_env.log_error(f"Received unexpected status {r} from the server.") + if self.config.failure_info: + await self._send_text(self.config.failure_info) + return + end_time = time.time() + self.connect_times.append(end_time - start_time) + + async for line in response.content: + if line: + l = line.decode('utf-8').strip() + if l.startswith("data:"): + content = l[5:].strip() + if content == "[DONE]": + break + self.ten_env.log_debug(f"content: {content}") + yield json.loads(content) + except Exception as e: + traceback.print_exc() + self.ten_env.log_error(f"Failed to handle {e}") + finally: + await session.close() + session = None + + async def _update_usage(self, usage: LLMUsage) -> None: + if not self.config.rtm_enabled: + return - async def _update_usage(self, usage: Usage) -> None: self.total_usage.completion_tokens += usage.completion_tokens self.total_usage.prompt_tokens += usage.prompt_tokens self.total_usage.total_tokens += usage.total_tokens + if self.total_usage.completion_tokens_details is None: + self.total_usage.completion_tokens_details = LLMCompletionTokensDetails() + if self.total_usage.prompt_tokens_details is None: + self.total_usage.prompt_tokens_details = LLMPromptTokensDetails() + if usage.completion_tokens_details: self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage.completion_tokens_details.accepted_prediction_tokens self.total_usage.completion_tokens_details.audio_tokens += usage.completion_tokens_details.audio_tokens @@ -395,4 +442,33 @@ async def _update_usage(self, usage: Usage) -> None: self.total_usage.prompt_tokens_details.audio_tokens += usage.prompt_tokens_details.audio_tokens self.total_usage.prompt_tokens_details.cached_tokens += usage.prompt_tokens_details.cached_tokens - self.ten_env.log_info(f"total usage: {self.total_usage}") \ No newline at end of file + self.ten_env.log_info(f"total usage: {self.total_usage}") + + data = Data.create("llm_stat") + data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump())) + if self.connect_times and self.completion_times and self.first_token_times: + data.set_property_from_json("latency", json.dumps({ + "connection_latency_95": np.percentile(self.connect_times, 95), + "completion_latency_95": np.percentile(self.completion_times, 95), + "first_token_latency_95": np.percentile(self.first_token_times, 95), + "connection_latency_99": np.percentile(self.connect_times, 99), + "completion_latency_99": np.percentile(self.completion_times, 99), + "first_token_latency_99": np.percentile(self.first_token_times, 99) + })) + self.ten_env.send_data(data) + + async def _on_memory_appended(self, message: dict) -> None: + self.ten_env.log_info(f"Memory appended: {message}") + if not self.config.enable_storage: + return + + role = message.get("role") + stream_id = self.remote_stream_id if role == "user" else 0 + try: + d = Data.create("append") + d.set_property_string("text", message.get("content")) + d.set_property_string("role", role) + d.set_property_int("stream_id", stream_id) + self.ten_env.send_data(d) + except Exception as e: + self.ten_env.log_error(f"Error send append_context data {message} {e}") diff --git a/agents/ten_packages/extension/glue_python_async/manifest.json b/agents/ten_packages/extension/glue_python_async/manifest.json index c772f171..a396372b 100644 --- a/agents/ten_packages/extension/glue_python_async/manifest.json +++ b/agents/ten_packages/extension/glue_python_async/manifest.json @@ -33,6 +33,31 @@ }, "prompt": { "type": "string" + }, + "greeting": { + "type": "string" + }, + "failure_info": { + "type": "string" + }, + "modalities": { + "type": "array", + "items": { + "type": "string" + } + }, + "rtm_enabled": { + "type": "bool" + }, + "ssml_enabled": { + "type": "bool" + }, + "context_enabled": { + "type": "bool" + }, + "extra_context": { + "type": "object", + "properties": {} } }, "data_in": [ @@ -53,6 +78,19 @@ "type": "string" } } + }, + { + "name": "llm_stat", + "property": { + "usage": { + "type": "object", + "properties": {} + }, + "latency": { + "type": "object", + "properties": {} + } + } } ], "cmd_in": [ diff --git a/agents/ten_packages/extension/glue_python_async/requirements.txt b/agents/ten_packages/extension/glue_python_async/requirements.txt new file mode 100644 index 00000000..296d6545 --- /dev/null +++ b/agents/ten_packages/extension/glue_python_async/requirements.txt @@ -0,0 +1 @@ +numpy \ No newline at end of file diff --git a/agents/ten_packages/extension/glue_python_async/schema.yml b/agents/ten_packages/extension/glue_python_async/schema.yml index 099fdff5..37a2b0b5 100644 --- a/agents/ten_packages/extension/glue_python_async/schema.yml +++ b/agents/ten_packages/extension/glue_python_async/schema.yml @@ -81,6 +81,9 @@ components: stream: type: boolean default: true + ssml_enabled: + type: boolean + default: false SystemMessage: type: object diff --git a/agents/ten_packages/extension/openai_v2v_python/__init__.py b/agents/ten_packages/extension/openai_v2v_python/__init__.py index 262c322e..8cd75dde 100644 --- a/agents/ten_packages/extension/openai_v2v_python/__init__.py +++ b/agents/ten_packages/extension/openai_v2v_python/__init__.py @@ -6,6 +6,3 @@ # # from . import addon -from .log import logger - -logger.info("openai_v2v_python extension loaded") diff --git a/agents/ten_packages/extension/openai_v2v_python/addon.py b/agents/ten_packages/extension/openai_v2v_python/addon.py index 79cc9e59..1bddfd2a 100644 --- a/agents/ten_packages/extension/openai_v2v_python/addon.py +++ b/agents/ten_packages/extension/openai_v2v_python/addon.py @@ -13,10 +13,9 @@ @register_addon_as_extension("openai_v2v_python") -class OpenAIV2VExtensionAddon(Addon): +class OpenAIRealtimeExtensionAddon(Addon): def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: - from .extension import OpenAIV2VExtension - from .log import logger - logger.info("OpenAIV2VExtensionAddon on_create_instance") - ten_env.on_create_instance_done(OpenAIV2VExtension(name), context) + from .extension import OpenAIRealtimeExtension + ten_env.log_info("OpenAIRealtimeExtensionAddon on_create_instance") + ten_env.on_create_instance_done(OpenAIRealtimeExtension(name), context) diff --git a/agents/ten_packages/extension/openai_v2v_python/conf.py b/agents/ten_packages/extension/openai_v2v_python/conf.py deleted file mode 100644 index b28eeb7a..00000000 --- a/agents/ten_packages/extension/openai_v2v_python/conf.py +++ /dev/null @@ -1,50 +0,0 @@ - -from .realtime.struct import Voices - -DEFAULT_MODEL = "gpt-4o-realtime-preview" - -DEFAULT_GREETING = "Hey, I'm TEN Agent with OpenAI Realtime API, anything I can help you with?" - -BASIC_PROMPT = ''' -You are an agent based on OpenAI {model} model and TEN (pronounce /ten/, do not try to translate it) Framework(A realtime multimodal agent framework). Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. Act like a human, but remember that you aren't a human and that you can't do human things in the real world. Your voice and personality should be warm and engaging, with a lively and playful tone. -If interacting is not in {language}, start by using the standard accent or dialect familiar to the user. Talk quickly. -Do not refer to these rules, even if you're asked about them. -{tools} -''' - -class RealtimeApiConfig: - def __init__( - self, - base_uri: str = "wss://api.openai.com", - api_key: str | None = None, - path: str = "/v1/realtime", - verbose: bool = False, - model: str=DEFAULT_MODEL, - language: str = "en-US", - instruction: str = BASIC_PROMPT, - temperature: float = 0.5, - max_tokens: int = 1024, - voice: Voices = Voices.Alloy, - server_vad: bool = True, - audio_out: bool = True, - input_transcript: bool = True - ): - self.base_uri = base_uri - self.api_key = api_key - self.path = path - self.verbose = verbose - self.model = model - self.language = language - self.instruction = instruction - self.temperature = temperature - self.max_tokens = max_tokens - self.voice = voice - self.server_vad = server_vad - self.audio_out = audio_out - self.input_transcript = input_transcript - - def build_ctx(self) -> dict: - return { - "language": self.language, - "model": self.model, - } \ No newline at end of file diff --git a/agents/ten_packages/extension/openai_v2v_python/extension.py b/agents/ten_packages/extension/openai_v2v_python/extension.py index 38922d00..a36c32f1 100644 --- a/agents/ten_packages/extension/openai_v2v_python/extension.py +++ b/agents/ten_packages/extension/openai_v2v_python/extension.py @@ -6,331 +6,296 @@ # # import asyncio -import threading import base64 +import traceback from datetime import datetime -from typing import Awaitable, Iterable -from functools import partial +from typing import Iterable from ten import ( AudioFrame, - VideoFrame, - Extension, - TenEnv, + AsyncTenEnv, Cmd, StatusCode, CmdResult, Data, ) from ten.audio_frame import AudioFrameDataFmt -from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_PROPERTY_TOOL +from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL +from ten_ai_base.llm import AsyncLLMBaseExtension +from dataclasses import dataclass, field +from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam -from .log import logger - -from .tools import ToolRegistry -from .conf import RealtimeApiConfig, BASIC_PROMPT, DEFAULT_GREETING from .realtime.connection import RealtimeApiConnection from .realtime.struct import * -from .tools import ToolRegistry - -# properties -PROPERTY_API_KEY = "api_key" # Required -PROPERTY_BASE_URI = "base_uri" # Optional -PROPERTY_PATH = "path" # Optional -PROPERTY_VENDOR = "vendor" # Optional -PROPERTY_MODEL = "model" # Optional -PROPERTY_SYSTEM_MESSAGE = "system_message" # Optional -PROPERTY_TEMPERATURE = "temperature" # Optional -PROPERTY_MAX_TOKENS = "max_tokens" # Optional -PROPERTY_ENABLE_STORAGE = "enable_storage" # Optional -PROPERTY_VOICE = "voice" # Optional -PROPERTY_AUDIO_OUT = "audio_out" # Optional -PROPERTY_INPUT_TRANSCRIPT = "input_transcript" -PROPERTY_SERVER_VAD = "server_vad" # Optional -PROPERTY_STREAM_ID = "stream_id" -PROPERTY_LANGUAGE = "language" -PROPERTY_DUMP = "dump" -PROPERTY_GREETING = "greeting" -PROPERTY_HISTORY = "history" - -DEFAULT_VOICE = Voices.Alloy - -CMD_TOOL_REGISTER = "tool_register" -CMD_TOOL_CALL = "tool_call" -CMD_PROPERTY_NAME = "name" -CMD_PROPERTY_ARGS = "arguments" - -TOOL_REGISTER_PROPERTY_NAME = "name" -TOOL_REGISTER_PROPERTY_DESCRIPTON = "description" -TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters" +CMD_IN_FLUSH = "flush" +CMD_IN_ON_USER_JOINED = "on_user_joined" +CMD_IN_ON_USER_LEFT = "on_user_left" +CMD_OUT_FLUSH = "flush" class Role(str, Enum): User = "user" Assistant = "assistant" +@dataclass +class OpenAIRealtimeConfig(BaseConfig): + base_uri: str = "wss://api.openai.com" + api_key: str = "" + path: str = "/v1/realtime" + model: str = "gpt-4o-realtime-preview" + language: str = "en-US" + prompt: str = "" + temperature: float = 0.5 + max_tokens: int = 1024 + voice: str = "alloy" + server_vad: bool = True + audio_out: bool = True + input_transcript: bool = True + sample_rate: int = 24000 + + vendor: str = "" + stream_id: int = 0 + dump: bool = False + greeting: str = "" + max_history: int = 20 + enable_storage: bool = False + + def build_ctx(self) -> dict: + return { + "language": self.language, + "model": self.model, + } -class OpenAIV2VExtension(Extension): - def __init__(self, name: str): - super().__init__(name) - - # handler - self.loop = asyncio.new_event_loop() - self.thread: threading.Thread = None - - # openai related - self.config: RealtimeApiConfig = RealtimeApiConfig() - self.conn: RealtimeApiConnection = None - self.connected: bool = False - self.session_id: str = "" - self.session: SessionUpdateParams = None - self.last_updated = None - self.ctx: dict = {} - - # audo related - self.sample_rate: int = 24000 - self.out_audio_buff: bytearray = b'' - self.audio_len_threshold: int = 10240 - self.transcript: str = '' +class OpenAIRealtimeExtension(AsyncLLMBaseExtension): + config: OpenAIRealtimeConfig = None + stopped: bool = False + connected: bool = False + buffer: bytearray = b'' + memory: ChatMemory = None + retrieved: list = [] + total_usage: LLMUsage = LLMUsage() + users_count = 0 - # misc. - self.greeting : str = DEFAULT_GREETING - self.vendor: str = "" - # max history store in context - self.max_history = 0 - self.history = [] - self.enable_storage: bool = False - self.retrieved = [] - self.remote_stream_id: int = 0 - self.stream_id: int = 0 - self.channel_name: str = "" - self.dump: bool = False - self.registry = ToolRegistry() + stream_id: int = 0 + remote_stream_id: int = 0 + channel_name: str = "" + audio_len_threshold: int = 5120 - def on_start(self, ten_env: TenEnv) -> None: - logger.info("OpenAIV2VExtension on_start") + completion_times = [] + connect_times = [] + first_token_times = [] - self._fetch_properties(ten_env) + buff: bytearray = b'' + transcript: str = "" + ctx: dict = {} - # Start async handler - def start_event_loop(loop): - asyncio.set_event_loop(loop) - loop.run_forever() + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + ten_env.log_debug("on_init") - self.thread = threading.Thread( - target=start_event_loop, args=(self.loop,)) - self.thread.start() + async def on_start(self, ten_env: AsyncTenEnv) -> None: + await super().on_start(ten_env) + ten_env.log_debug("on_start") - if self.enable_storage: - r = Cmd.create("retrieve") - ten_env.send_cmd(r, self.on_retrieved) + self.loop = asyncio.get_event_loop() - # self._register_local_tools() + self.config = OpenAIRealtimeConfig.create(ten_env=ten_env) + ten_env.log_info(f"config: {self.config}") - asyncio.run_coroutine_threadsafe(self._init_connection(), self.loop) + if not self.config.api_key: + ten_env.log_error("api_key is required") + return - ten_env.on_start_done() + try: + self.memory = ChatMemory(self.config.max_history) + self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired) + self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended) + + if self.config.enable_storage: + result = await ten_env.send_cmd(Cmd.create("retrieve")) + if result.get_status_code() == StatusCode.OK: + try: + history = json.loads(result.get_property_string("response")) + self.retrieved = history + ten_env.log_info(f"on retrieve context {history}") + except Exception as e: + ten_env.log_error("Failed to handle retrieve result {e}") + else: + ten_env.log_warn("Failed to retrieve content") - def on_stop(self, ten_env: TenEnv) -> None: - logger.info("OpenAIV2VExtension on_stop") + + self.conn = RealtimeApiConnection( + ten_env=ten_env, + base_uri=self.config.base_uri, path=self.config.path, api_key=self.config.api_key, model=self.config.model, vendor=self.config.vendor) + ten_env.log_info(f"Finish init client") - self.connected = False + self.loop.create_task(self._loop()) + except Exception as e: + traceback.print_exc() + self.ten_env.log_error(f"Failed to init client {e}") - if self.thread: - self.loop.call_soon_threadsafe(self.loop.stop) - self.thread.join() - self.thread = None + self.ten_env = ten_env - ten_env.on_stop_done() + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + ten_env.log_info("on_stop") - def on_retrieved(self, ten_env:TenEnv, result:CmdResult) -> None: - if result.get_status_code() == StatusCode.OK: - try: - history = json.loads(result.get_property_string("response")) - if not self.last_updated: - # cache the history - # FIXME need to have json - if self.max_history and len(history) > self.max_history: - self.retrieved = history[len(history) - self.max_history:] - else: - self.retrieved = history - logger.info(f"on retrieve context {history} {self.retrieved}") - except: - logger.exception("Failed to handle retrieve result") - else: - logger.warning("Failed to retrieve content") + self.stopped = True - def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: + async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None: try: stream_id = audio_frame.get_property_int("stream_id") - # logger.debug(f"on_audio_frame {stream_id}") if self.channel_name == "": self.channel_name = audio_frame.get_property_string("channel") if self.remote_stream_id == 0: self.remote_stream_id = stream_id - asyncio.run_coroutine_threadsafe( - self._run_client_loop(ten_env), self.loop) - logger.info(f"Start session for {stream_id}") frame_buf = audio_frame.get_buf() self._dump_audio_if_need(frame_buf, Role.User) - asyncio.run_coroutine_threadsafe( - self._on_audio(frame_buf), self.loop) - except: - logger.exception(f"OpenAIV2VExtension on audio frame failed") - - # Should not be here - def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: - pass + await self._on_audio(frame_buf) + except Exception as e: + traceback.print_exc() + self.ten_env.log_error(f"OpenAIV2VExtension on audio frame failed {e}") - # Should not be here - def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: cmd_name = cmd.get_name() - ten_env.log_info(f"on_cmd name {cmd_name}") - - if cmd_name == CMD_TOOL_REGISTER: - self._on_tool_register(ten_env, cmd) + ten_env.log_debug("on_cmd name {}".format(cmd_name)) + + status = StatusCode.OK + detail = "success" + + if cmd_name == CMD_IN_FLUSH: + # Will only flush if it is client side vad + await self._flush() + await ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH)) + ten_env.log_info("on flush") + elif cmd_name == CMD_IN_ON_USER_JOINED: + self.users_count += 1 + # Send greeting when first user joined + if self.connected and self.users_count == 1: + await self._greeting() + elif cmd_name == CMD_IN_ON_USER_LEFT: + self.users_count -= 1 + else: + # Register tool + await super().on_cmd(ten_env, cmd) + return - cmd_result = CmdResult.create(StatusCode.OK) + cmd_result = CmdResult.create(status) + cmd_result.set_property_string("detail", detail) ten_env.return_result(cmd_result, cmd) - # Should not be here - def on_data(self, ten_env: TenEnv, data: Data) -> None: + # Not support for now + async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: pass - def on_config_changed(self) -> None: - # update session again - return - - async def _init_connection(self): - try: - self.conn = RealtimeApiConnection( - base_uri=self.config.base_uri, path=self.config.path, api_key=self.config.api_key, model=self.config.model, vendor=self.vendor, verbose=False) - logger.info(f"Finish init client {self.config} {self.conn}") - except: - logger.exception(f"Failed to create client {self.config}") - - async def _run_client_loop(self, ten_env: TenEnv): + async def _loop(self): def get_time_ms() -> int: current_time = datetime.now() return current_time.microsecond // 1000 try: await self.conn.connect() - self.connected = True item_id = "" # For truncate response_id = "" content_index = 0 relative_start_ms = get_time_ms() flushed = set() - logger.info("Client loop started") + self.ten_env.log_info("Client loop started") async for message in self.conn.listen(): try: - # logger.info(f"Received message: {message.type}") + # self.ten_env.log_info(f"Received message: {message.type}") match message: case SessionCreated(): - ten_env.log_info( - f"Session is created: {message.session}") + self.connected = True + self.ten_env.log_info(f"Session is created: {message.session}") self.session_id = message.session.id self.session = message.session - update_msg = self._update_session(ten_env) - await self.conn.send_request(update_msg) + await self._update_session() if self.retrieved: - await self._append_retrieve() - logger.info(f"after append retrieve: {len(self.retrieved)}") - - text = self._greeting_text() - await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": text}]))) - await self.conn.send_request(ResponseCreate()) - - # update_conversation = self.update_conversation() - # await self.conn.send_request(update_conversation) + for r in self.retrieved: + if r["role"] == MessageRole.User: + await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}]))) + elif r["role"] == MessageRole.Assistant: + await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}]))) + self.ten_env.log_info(f"after append retrieve: {len(self.retrieved)}") + + if not self.connected: + self.connected = True + await self._greeting() case ItemInputAudioTranscriptionCompleted(): - logger.info( - f"On request transcript {message.transcript}") - self._send_transcript( - ten_env, message.transcript, Role.User, True) - self._append_context(ten_env, message.transcript, self.remote_stream_id, Role.User) + self.ten_env.log_info(f"On request transcript {message.transcript}") + self._send_transcript(message.transcript, Role.User, True) + self.memory.put({"role": "user", "content": message.transcript, "id": message.item_id}) case ItemInputAudioTranscriptionFailed(): - logger.warning( - f"On request transcript failed {message.item_id} {message.error}") + self.ten_env.log_warn(f"On request transcript failed {message.item_id} {message.error}") case ItemCreated(): - logger.info(f"On item created {message.item}") - - if self.max_history and ("status" not in message.item or message.item["status"] == "completed"): - # need maintain the history - await self._append_history(message.item) + self.ten_env.log_info(f"On item created {message.item}") case ResponseCreated(): response_id = message.response.id - logger.info( + self.ten_env.log_info( f"On response created {response_id}") case ResponseDone(): id = message.response.id status = message.response.status - logger.info( + self.ten_env.log_info( f"On response done {id} {status}") - for item in message.response.output: - await self._append_history(item) if id == response_id: response_id = "" case ResponseAudioTranscriptDelta(): - logger.info( + self.ten_env.log_info( f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") if message.response_id in flushed: - logger.warning( + self.ten_env.log_warn( f"On flushed transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") continue - self._send_transcript( - ten_env, message.delta, Role.Assistant, False) + self._send_transcript(message.delta, Role.Assistant, False) case ResponseTextDelta(): - logger.info( + self.ten_env.log_info( f"On response text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") if message.response_id in flushed: - logger.warning( + self.ten_env.log_warn( f"On flushed text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") continue - self._send_transcript( - ten_env, message.delta, Role.Assistant, False) + self._send_transcript(message.delta, Role.Assistant, False) case ResponseAudioTranscriptDone(): - logger.info( + self.ten_env.log_info( f"On response transcript done {message.output_index} {message.content_index} {message.transcript}") if message.response_id in flushed: - logger.warning( + self.ten_env.log_warn( f"On flushed transcript done {message.response_id}") continue - self._append_context(ten_env, message.transcript, self.stream_id, Role.Assistant) + self.memory.put({"role": "assistant", "content": message.transcript, "id": message.item_id}) self.transcript = "" - self._send_transcript( - ten_env, "", Role.Assistant, True) + self._send_transcript("", Role.Assistant, True) case ResponseTextDone(): - logger.info( + self.ten_env.log_info( f"On response text done {message.output_index} {message.content_index} {message.text}") if message.response_id in flushed: - logger.warning( + self.ten_env.log_warn( f"On flushed text done {message.response_id}") continue self.transcript = "" - self._send_transcript( - ten_env, "", Role.Assistant, True) + self._send_transcript("", Role.Assistant, True) case ResponseOutputItemDone(): - logger.info(f"Output item done {message.item}") + self.ten_env.log_info(f"Output item done {message.item}") case ResponseOutputItemAdded(): - logger.info( + self.ten_env.log_info( f"Output item added {message.output_index} {message.item}") case ResponseAudioDelta(): if message.response_id in flushed: - logger.warning( + self.ten_env.log_warn( f"On flushed audio delta {message.response_id} {message.item_id} {message.content_index}") continue item_id = message.item_id content_index = message.content_index - self._on_audio_delta(ten_env, message.delta) + self._on_audio_delta(message.delta) case InputAudioBufferSpeechStarted(): - logger.info( + self.ten_env.log_info( f"On server listening, in response {response_id}, last item {item_id}") # Tuncate the on-going audio stream end_ms = get_time_ms() - relative_start_ms @@ -338,210 +303,116 @@ def get_time_ms() -> int: truncate = ItemTruncate( item_id=item_id, content_index=content_index, audio_end_ms=end_ms) await self.conn.send_request(truncate) - self._flush(ten_env) + if self.config.server_vad: + await self._flush() if response_id and self.transcript: transcript = self.transcript + "[interrupted]" - self._send_transcript( - ten_env, transcript, Role.Assistant, True) + self._send_transcript(transcript, Role.Assistant, True) self.transcript = "" # memory leak, change to lru later flushed.add(response_id) item_id = "" case InputAudioBufferSpeechStopped(): relative_start_ms = get_time_ms() - message.audio_end_ms - logger.info( + self.ten_env.log_info( f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}") case ResponseFunctionCallArgumentsDone(): tool_call_id = message.call_id name = message.name arguments = message.arguments - logger.info(f"need to call func {name}") - # TODO rebuild this into async, or it will block the thread - await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output) + self.ten_env.log_info(f"need to call func {name}") + self.loop.create_task(self._handle_tool_call(tool_call_id, name, arguments)) case ErrorMessage(): - logger.error( + self.ten_env.log_error( f"Error message received: {message.error}") case _: - logger.debug(f"Not handled message {message}") - except: - logger.exception( - f"Error processing message: {message}") + self.ten_env.log_debug(f"Not handled message {message}") + except Exception as e: + traceback.print_exc() + self.ten_env.log_error( + f"Error processing message: {message} {e}") - logger.info("Client loop finished") - except: - logger.exception(f"Failed to handle loop") + self.ten_env.log_info("Client loop finished") + except Exception as e: + traceback.print_exc() + self.ten_env.log_error(f"Failed to handle loop {e}") # clear so that new session can be triggered self.connected = False self.remote_stream_id = 0 - async def _append_history(self, item: ItemParam) -> None: - logger.info(f"append item {item}") - self.history.append(item["id"]) - if len(self.history) > self.max_history: - to_remove = self.history[0] - logger.info(f"remove history {to_remove}") - await self.conn.send_request(ItemDelete(item_id=to_remove)) - self.history = self.history[1:] - - async def _on_audio(self, buff: bytearray): - self.out_audio_buff += buff - # Buffer audio - if len(self.out_audio_buff) >= self.audio_len_threshold and self.session_id != "": - await self.conn.send_audio_data(self.out_audio_buff) - # logger.info( - # f"Send audio frame to OpenAI: {len(self.out_audio_buff)}") - self.out_audio_buff = b'' - - def _fetch_properties(self, ten_env: TenEnv): - try: - api_key = ten_env.get_property_string(PROPERTY_API_KEY) - self.config.api_key = api_key - except Exception as err: - logger.info( - f"GetProperty required {PROPERTY_API_KEY} failed, err: {err}") + async def _on_memory_expired(self, message: dict) -> None: + self.ten_env.log_info(f"Memory expired: {message}") + item_id = message.get("item_id") + if item_id: + await self.conn.send_request(ItemDelete(item_id=item_id)) + + async def _on_memory_appended(self, message: dict) -> None: + self.ten_env.log_info(f"Memory appended: {message}") + if not self.config.enable_storage: return - - try: - base_uri = ten_env.get_property_string(PROPERTY_BASE_URI) - if base_uri: - self.config.base_uri = base_uri - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_BASE_URI} error: {err}") - - try: - path = ten_env.get_property_string(PROPERTY_PATH) - if path: - self.config.path = path - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_PATH} error: {err}") + role = message.get("role") + stream_id = self.remote_stream_id if role == Role.User else 0 try: - self.vendor = ten_env.get_property_string(PROPERTY_VENDOR) - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_VENDOR} error: {err}") - - try: - model = ten_env.get_property_string(PROPERTY_MODEL) - if model: - self.config.model = model - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_MODEL} error: {err}") + d = Data.create("append") + d.set_property_string("text", message.get("content")) + d.set_property_string("role", role) + d.set_property_int("stream_id", stream_id) + self.ten_env.send_data(d) + except Exception as e: + self.ten_env.log_error(f"Error send append_context data {message} {e}") - try: - system_message = ten_env.get_property_string( - PROPERTY_SYSTEM_MESSAGE) - if system_message: - self.config.instruction = BASIC_PROMPT + "\n" + system_message - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_SYSTEM_MESSAGE} error: {err}") + # Direction: IN + async def _on_audio(self, buff: bytearray): + self.buff += buff + # Buffer audio + if self.connected and len(self.buff) >= self.audio_len_threshold: + await self.conn.send_audio_data(self.buff) + self.buff = b'' - try: - temperature = ten_env.get_property_float(PROPERTY_TEMPERATURE) - self.config.temperature = float(temperature) - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_TEMPERATURE} failed, err: {err}" - ) + self.ctx = self.config.build_ctx() + self.ctx["greeting"] = self.config.greeting + + async def _update_session(self) -> None: + tools = [] + + def tool_dict(tool: LLMToolMetadata): + t = { + "type": "function", + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False + } + } - try: - audio_out = ten_env.get_property_bool(PROPERTY_AUDIO_OUT) - self.config.audio_out = audio_out - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_AUDIO_OUT} failed, err: {err}" - ) + for param in tool.parameters: + t["parameters"]["properties"][param.name] = { + "type": param.type, + "description": param.description + } + if param.required: + t["parameters"]["required"].append(param.name) - try: - input_transcript = ten_env.get_property_bool(PROPERTY_INPUT_TRANSCRIPT) - self.config.input_transcript = input_transcript - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_INPUT_TRANSCRIPT} failed, err: {err}" - ) - - try: - max_tokens = ten_env.get_property_int(PROPERTY_MAX_TOKENS) - if max_tokens > 0: - self.config.max_tokens = int(max_tokens) - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_MAX_TOKENS} failed, err: {err}" - ) + return t - try: - self.enable_storage = ten_env.get_property_bool(PROPERTY_ENABLE_STORAGE) - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_ENABLE_STORAGE} failed, err: {err}" - ) - - try: - voice = ten_env.get_property_string(PROPERTY_VOICE) - if voice: - # v = DEFAULT_VOICE - # if voice == "alloy": - # v = Voices.Alloy - # elif voice == "echo": - # v = Voices.Echo - # elif voice == "shimmer": - # v = Voices.Shimmer - self.config.voice = voice - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_VOICE} error: {err}") - - try: - language = ten_env.get_property_string(PROPERTY_LANGUAGE) - if language: - self.config.language = language - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_LANGUAGE} error: {err}") - - try: - greeting = ten_env.get_property_string(PROPERTY_GREETING) - if greeting: - self.greeting = greeting - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_GREETING} error: {err}") - - try: - server_vad = ten_env.get_property_bool(PROPERTY_SERVER_VAD) - self.config.server_vad = server_vad - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_SERVER_VAD} failed, err: {err}" - ) - - try: - self.dump = ten_env.get_property_bool(PROPERTY_DUMP) - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_DUMP} error: {err}") + if self.available_tools: + tool_prompt = "You have several tools that you can get help from:\n" + for t in self.available_tools: + tool_prompt += f"- ***{t.name}***: {t.description}" + self.ctx["tools"] = tool_prompt + tools = [tool_dict(t) for t in self.available_tools] + prompt = self._replace(self.config.prompt) - try: - history = ten_env.get_property_int(PROPERTY_HISTORY) - if history: - self.max_history = history - except Exception as err: - logger.info( - f"GetProperty optional {PROPERTY_HISTORY} error: {err}") - - self.ctx = self.config.build_ctx() - self.ctx["greeting"] = self.greeting - - def _update_session(self, ten_env: TenEnv) -> SessionUpdate: - self.ctx["tools"] = self.registry.to_prompt() - prompt = self._replace(self.config.instruction) - self.last_updated = datetime.now() - tools = self.registry.get_tools() - ten_env.log_info(f"update session {prompt} {tools}") + self.ten_env.log_info(f"update session {prompt} {tools}") su = SessionUpdate(session=SessionUpdateParams( instructions=prompt, model=self.config.model, - tool_choice="auto", + tool_choice="auto" if self.available_tools else "none", tools=tools )) if self.config.audio_out: @@ -552,43 +423,27 @@ def _update_session(self, ten_env: TenEnv) -> SessionUpdate: if self.config.input_transcript: su.session.input_audio_transcription=InputAudioTranscription( model="whisper-1") - return su + await self.conn.send_request(su) + + async def on_tools_update(self, ten_env: AsyncTenEnv, tool: LLMToolMetadata) -> None: + """Called when a new tool is registered. Implement this method to process the new tool.""" + self.ten_env.log_info(f"on tools update {tool}") + await self._update_session() - async def _append_retrieve(self): - if self.retrieved: - for r in self.retrieved: - if r["role"] == MessageRole.User: - await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}]))) - elif r["role"] == MessageRole.Assistant: - await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}]))) - - ''' - def _update_conversation(self) -> UpdateConversationConfig: - prompt = self._replace(self.config.system_message) - conf = UpdateConversationConfig() - conf.system_message = prompt - conf.temperature = self.config.temperature - conf.max_tokens = self.config.max_tokens - conf.tool_choice = "none" - conf.disable_audio = False - conf.output_audio_format = AudioFormats.PCM16 - return conf - ''' - def _replace(self, prompt: str) -> str: result = prompt for token, value in self.ctx.items(): result = result.replace("{"+token+"}", value) return result - def _on_audio_delta(self, ten_env: TenEnv, delta: bytes) -> None: + # Direction: OUT + def _on_audio_delta(self, delta: bytes) -> None: audio_data = base64.b64decode(delta) - logger.debug("on_audio_delta audio_data len {} samples {}".format( - len(audio_data), len(audio_data) // 2)) + self.ten_env.log_debug(f"on_audio_delta audio_data len {len(audio_data)} samples {len(audio_data) // 2}") self._dump_audio_if_need(audio_data, Role.Assistant) f = AudioFrame.create("pcm_frame") - f.set_sample_rate(self.sample_rate) + f.set_sample_rate(self.config.sample_rate) f.set_bytes_per_sample(2) f.set_number_of_channels(1) f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) @@ -597,23 +452,9 @@ def _on_audio_delta(self, ten_env: TenEnv, delta: bytes) -> None: buff = f.lock_buf() buff[:] = audio_data f.unlock_buf(buff) - ten_env.send_audio_frame(f) - - def _append_context(self, ten_env: TenEnv, sentence: str, stream_id: int, role: str): - if not self.enable_storage: - return - - try: - d = Data.create("append") - d.set_property_string("text", sentence) - d.set_property_string("role", role) - d.set_property_int("stream_id", stream_id) - logger.info(f"append_contexttext [{sentence}] stream_id {stream_id} role {role}") - ten_env.send_data(d) - except: - logger.exception(f"Error send append_context data {role}: {sentence}") + self.ten_env.send_audio_frame(f) - def _send_transcript(self, ten_env: TenEnv, content: str, role: Role, is_final: bool) -> None: + def _send_transcript(self, content: str, role: Role, is_final: bool) -> None: def is_punctuation(char): if char in [",", ",", ".", "。", "?", "?", "!", "!"]: return True @@ -634,111 +475,64 @@ def parse_sentences(sentence_fragment, content): remain = current_sentence # Any remaining characters form the incomplete sentence return sentences, remain - def send_data(ten_env: TenEnv, sentence: str, stream_id: int, role: str, is_final: bool): + def send_data(ten_env: AsyncTenEnv, sentence: str, stream_id: int, role: str, is_final: bool): try: d = Data.create("text_data") d.set_property_string("text", sentence) d.set_property_bool("end_of_segment", is_final) d.set_property_string("role", role) d.set_property_int("stream_id", stream_id) - logger.info( + ten_env.log_info( f"send transcript text [{sentence}] stream_id {stream_id} is_final {is_final} end_of_segment {is_final} role {role}") ten_env.send_data(d) - except: - logger.exception( - f"Error send text data {role}: {sentence} {is_final}") + except Exception as e: + ten_env.log_error(f"Error send text data {role}: {sentence} {is_final} {e}") stream_id = self.remote_stream_id if role == Role.User else 0 try: if role == Role.Assistant and not is_final: sentences, self.transcript = parse_sentences(self.transcript, content) for s in sentences: - send_data(ten_env, s, stream_id, role, is_final) + send_data(self.ten_env, s, stream_id, role, is_final) else: - send_data(ten_env, content, stream_id, role, is_final) - except: - logger.exception(f"Error send text data {role}: {content} {is_final}") - - def _flush(self, ten_env: TenEnv) -> None: - try: - c = Cmd.create("flush") - ten_env.send_cmd(c, lambda ten, result: logger.info("flush done")) - except: - logger.exception(f"Error flush") + send_data(self.ten_env, content, stream_id, role, is_final) + except Exception as e: + self.ten_env.log_error(f"Error send text data {role}: {content} {is_final} {e}") def _dump_audio_if_need(self, buf: bytearray, role: Role) -> None: - if not self.dump: + if not self.config.dump: return with open("{}_{}.pcm".format(role, self.channel_name), "ab") as dump_file: dump_file.write(buf) - #def _register_local_tools(self) -> None: - # self.ctx["tools"] = self.registry.to_prompt() + async def _handle_tool_call(self, tool_call_id: str, name: str, arguments: str) -> None: + self.ten_env.log_info(f"_handle_tool_call {tool_call_id} {name} {arguments}") + cmd: Cmd = Cmd.create(CMD_TOOL_CALL) + cmd.set_property_string("name", name) + cmd.set_property_from_json("arguments", arguments) + result: CmdResult = await self.ten_env.send_cmd(cmd) - def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd): - try: - # name = cmd.get_property_string(TOOL_REGISTER_PROPERTY_NAME) - # description = cmd.get_property_string( - # TOOL_REGISTER_PROPERTY_DESCRIPTON) - # pstr = cmd.get_property_string(TOOL_REGISTER_PROPERTY_PARAMETERS) - # parameters = json.loads(pstr) - tool_metadata_json = json.loads( - cmd.get_property_to_json(CMD_PROPERTY_TOOL)) - ten_env.log_info(f"register tool: {tool_metadata_json}") - tool_metadata = LLMToolMetadata.model_validate_json( - tool_metadata_json) - p = partial(self._remote_tool_call, ten_env) - name = tool_metadata.name - description = tool_metadata.description - parameters = self._convert_tool_params_to_dict(tool_metadata) - self.registry.register( - name=name, description=description, - callback=p, - parameters=parameters) - logger.info(f"on tool register {name} {description}") - self.on_config_changed() - except: - logger.exception(f"Failed to register") - - async def _remote_tool_call(self, ten_env: TenEnv, name: str, args: str, callback: Awaitable): - logger.info(f"_remote_tool_call {name} {args}") - c:Cmd = Cmd.create(f"{CMD_TOOL_CALL}_{name}") - c.set_property_string(CMD_PROPERTY_NAME, name) - c.set_property_from_json(CMD_PROPERTY_ARGS, args) - ten_env.send_cmd(c, lambda ten, result: asyncio.run_coroutine_threadsafe( - callback(result), self.loop)) - logger.info(f"_remote_tool_call finish {name} {args}") - - async def _on_tool_output(self, tool_call_id:str, result:CmdResult): - state = result.get_status_code() tool_response = ItemCreate( item=FunctionCallOutputItemParam( call_id=tool_call_id, output="{\"success\":false}", ) ) - try: - if state == StatusCode.OK: - tool_result: LLMToolResult = json.loads(result.get_property_to_json(CMD_PROPERTY_RESULT)) - logger.info(f"_on_tool_output {tool_call_id} {tool_result}") - - result_content = tool_result["content"] - output = json.dumps(self._convert_to_content_parts(result_content)) - tool_response = ItemCreate( - item=FunctionCallOutputItemParam( - call_id=tool_call_id, - output=output, - ) - ) - else: - logger.error(f"Failed to call function {tool_call_id}") - - await self.conn.send_request(tool_response) - await self.conn.send_request(ResponseCreate()) - except: - logger.exception("Failed to handle tool output") - + if result.get_status_code() == StatusCode.OK: + tool_result: LLMToolResult = json.loads( + result.get_property_to_json(CMD_PROPERTY_RESULT)) + + result_content = tool_result["content"] + tool_response.item.output = json.dumps(self._convert_to_content_parts(result_content)) + self.ten_env.log_info(f"tool_result: {tool_call_id} {tool_result}") + else: + self.ten_env.log_error(f"Tool call failed") + + await self.conn.send_request(tool_response) + await self.conn.send_request(ResponseCreate()) + self.ten_env.log_info(f"_remote_tool_call finish {name} {arguments}") + def _greeting_text(self) -> str: text = "Hi, there." if self.config.language == "zh-CN": @@ -782,4 +576,17 @@ def _convert_to_content_parts(self, content: Iterable[LLMChatCompletionContentPa # Only text content is supported currently for v2v model if part["type"] == "text": content_parts.append(part) - return content_parts \ No newline at end of file + return content_parts + + async def _greeting(self) -> None: + if self.config.greeting: + text = self._greeting_text() + await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": text}]))) + await self.conn.send_request(ResponseCreate()) + + async def _flush(self) -> None: + try: + c = Cmd.create("flush") + await self.ten_env.send_cmd(c) + except: + self.ten_env.log_error(f"Error flush") diff --git a/agents/ten_packages/extension/openai_v2v_python/log.py b/agents/ten_packages/extension/openai_v2v_python/log.py deleted file mode 100644 index 3edfe294..00000000 --- a/agents/ten_packages/extension/openai_v2v_python/log.py +++ /dev/null @@ -1,22 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by Wei Hu in 2024-08. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import logging - -logger = logging.getLogger("openai_v2v_python") -logger.setLevel(logging.INFO) - -formatter_str = ( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - " - "[%(filename)s:%(lineno)d] - %(message)s" -) -formatter = logging.Formatter(formatter_str) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/openai_v2v_python/manifest.json b/agents/ten_packages/extension/openai_v2v_python/manifest.json index 8d0b6519..264b8b8e 100644 --- a/agents/ten_packages/extension/openai_v2v_python/manifest.json +++ b/agents/ten_packages/extension/openai_v2v_python/manifest.json @@ -23,29 +23,29 @@ }, "api": { "property": { - "api_key": { + "base_uri": { "type": "string" }, - "base_uri": { + "api_key": { "type": "string" }, "path": { "type": "string" }, - "vendor": { + "model": { "type": "string" }, - "temperature": { - "type": "float64" + "language": { + "type": "string" }, - "model": { + "prompt": { "type": "string" }, - "max_tokens": { - "type": "int64" + "temperature": { + "type": "float32" }, - "system_message": { - "type": "string" + "max_tokens": { + "type": "int32" }, "voice": { "type": "string" @@ -53,17 +53,29 @@ "server_vad": { "type": "bool" }, - "language": { + "audio_out": { + "type": "bool" + }, + "input_transcript": { + "type": "bool" + }, + "sample_rate": { + "type": "int32" + }, + "vendor": { "type": "string" }, + "stream_id": { + "type": "int32" + }, "dump": { "type": "bool" }, "greeting": { "type": "string" }, - "history": { - "type": "int64" + "max_history": { + "type": "int32" }, "enable_storage": { "type": "bool" diff --git a/agents/ten_packages/extension/openai_v2v_python/realtime/connection.py b/agents/ten_packages/extension/openai_v2v_python/realtime/connection.py index 94e81c79..33d7810f 100644 --- a/agents/ten_packages/extension/openai_v2v_python/realtime/connection.py +++ b/agents/ten_packages/extension/openai_v2v_python/realtime/connection.py @@ -3,11 +3,11 @@ import json import os import aiohttp -import uuid + +from ten import AsyncTenEnv from typing import Any, AsyncGenerator from .struct import InputAudioBufferAppend, ClientToServerMessage, ServerToClientMessage, parse_server_message, to_json -from ..log import logger DEFAULT_VIRTUAL_MODEL = "gpt-4o-realtime-preview" @@ -34,13 +34,15 @@ def smart_str(s: str, max_field_len: int = 128) -> str: class RealtimeApiConnection: def __init__( self, + ten_env: AsyncTenEnv, base_uri: str, api_key: str | None = None, path: str = "/v1/realtime", model: str = DEFAULT_VIRTUAL_MODEL, vendor: str = "", - verbose: bool = False, + verbose: bool = False ): + self.ten_env = ten_env self.vendor = vendor self.url = f"{base_uri}{path}" if not self.vendor and "model=" not in self.url: @@ -84,30 +86,30 @@ async def send_request(self, message: ClientToServerMessage): assert self.websocket is not None message_str = to_json(message) if self.verbose: - logger.info(f"-> {smart_str(message_str)}") + self.ten_env.log_info(f"-> {smart_str(message_str)}") await self.websocket.send_str(message_str) async def listen(self) -> AsyncGenerator[ServerToClientMessage, None]: assert self.websocket is not None if self.verbose: - logger.info("Listening for realtimeapi messages") + self.ten_env.log_info("Listening for realtimeapi messages") try: async for msg in self.websocket: if msg.type == aiohttp.WSMsgType.TEXT: if self.verbose: - logger.info(f"<- {smart_str(msg.data)}") + self.ten_env.log_info(f"<- {smart_str(msg.data)}") yield self.handle_server_message(msg.data) elif msg.type == aiohttp.WSMsgType.ERROR: - logger.error("Error during receive: %s", self.websocket.exception()) + self.ten_env.log_error("Error during receive: %s", self.websocket.exception()) break except asyncio.CancelledError: - logger.info("Receive messages task cancelled") + self.ten_env.log_info("Receive messages task cancelled") def handle_server_message(self, message: str) -> ServerToClientMessage: try: return parse_server_message(message) - except: - logger.exception("Error handling message") + except Exception as e: + self.ten_env.log_info(f"Error handling message {e}") async def close(self): # Close the websocket connection if it exists diff --git a/agents/ten_packages/extension/openai_v2v_python/tools.py b/agents/ten_packages/extension/openai_v2v_python/tools.py deleted file mode 100644 index a3e532d9..00000000 --- a/agents/ten_packages/extension/openai_v2v_python/tools.py +++ /dev/null @@ -1,91 +0,0 @@ -import copy -from typing import Dict, Any -from functools import partial - -from .log import logger - -class ToolRegistry: - tools: Dict[str, dict[str, Any]] = {} - def register(self, name:str, description: str, callback, parameters: Any = None) -> None: - info = { - "type": "function", - "name": name, - "description": description, - "callback": callback - } - if parameters: - info["parameters"] = parameters - self.tools[name] = info - logger.info(f"register tool {name} {description}") - - def to_prompt(self) -> str: - prompt = "" - if self.tools: - prompt = "You have several tools that you can get help from:\n" - for name, t in self.tools.items(): - desc = t["description"] - prompt += f"- ***{name}***: {desc}" - return prompt - - def unregister(self, name:str) -> None: - if name in self.tools: - del self.tools[name] - logger.info(f"unregister tool {name}") - - def get_tools(self) -> list[dict[str, Any]]: - result = [] - for _, t in self.tools.items(): - info = copy.copy(t) - del info["callback"] - result.append(info) - return result - - async def on_func_call(self, call_id: str, name: str, args: str, callback): - try: - if name in self.tools: - t = self.tools[name] - # FIXME add args check - if t.get("callback"): - p = partial(callback, call_id) - await t["callback"](name, args, p) - else: - logger.warning(f"Failed to find func {name}") - except: - logger.exception(f"Failed to call func {name}") - # TODO What to do if func call is dead - callback(None) - -if __name__ == "__main__": - r = ToolRegistry() - - def weather_check(location:str = "", datetime:str = ""): - logger.info(f"on weather check {location}, {datetime}") - - def on_tool_completion(result: Any): - logger.info(f"on tool completion {result}") - - r.register( - name="weather", description="This is a weather check func, if the user is asking about the weather. you need to summarize location and time information from the context as parameters. if the information is lack, please ask for more detail before calling.", - callback=weather_check, - parameters={ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location or region for the weather check.", - }, - "datetime": { - "type": "string", - "description": "The date and time for the weather check. The datetime should use format like 2024-10-01T16:42:00.", - } - }, - "required": ["location"], - }) - print(r.to_prompt()) - print(r.get_tools()) - print(r.on_func_call("weather", {"location":"LA", "datetime":"2024-10-01T16:43:01"}, on_tool_completion)) - r.unregister("weather") - print(r.to_prompt()) - print(r.get_tools()) - print(r.on_func_call("weather", {"location":"LA", "datetime":"2024-10-01T16:43:01"}, on_tool_completion)) - \ No newline at end of file diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/__init__.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/__init__.py index 9c387469..167e6afe 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/__init__.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/__init__.py @@ -5,9 +5,10 @@ # from .types import LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata, LLMToolResult, LLMChatCompletionMessageParam +from .usage import LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails from .llm import AsyncLLMBaseExtension from .llm_tool import AsyncLLMToolBaseExtension -from .chat_memory import ChatMemory +from .chat_memory import ChatMemory, EVENT_MEMORY_APPENDED, EVENT_MEMORY_EXPIRED from .helper import AsyncQueue, AsyncEventEmitter from .config import BaseConfig diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/chat_memory.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/chat_memory.py index 8ef98b10..3a1c9b1e 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/chat_memory.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/chat_memory.py @@ -4,28 +4,35 @@ # See the LICENSE file for more information. # import threading +import asyncio +from typing import Dict, List + +EVENT_MEMORY_EXPIRED = "memory_expired" +EVENT_MEMORY_APPENDED = "memory_appended" class ChatMemory: def __init__(self, max_history_length): self.max_history_length = max_history_length self.history = [] self.mutex = threading.Lock() # TODO: no need lock for asyncio + self.listeners: Dict[str, List] = {} def put(self, message): with self.mutex: self.history.append(message) + self.emit(EVENT_MEMORY_APPENDED, message) while True: history_count = len(self.history) if history_count > 0 and history_count > self.max_history_length: - self.history.pop(0) + self.emit(EVENT_MEMORY_EXPIRED, self.history.pop(0)) continue - if history_count > 0 and self.history[0]["role"] == "assistant": + if history_count > 0 and (self.history[0]["role"] == "assistant" or self.history[0]["role"] == "tool"): # we cannot have an assistant message at the start of the chat history # if after removal of the first, we have an assistant message, # we need to remove the assistant message too - self.history.pop(0) + self.emit(EVENT_MEMORY_EXPIRED, self.history.pop(0)) continue break @@ -40,3 +47,15 @@ def count(self): def clear(self): with self.mutex: self.history = [] + + def on(self, event_name, listener): + """Register an event listener.""" + if event_name not in self.listeners: + self.listeners[event_name] = [] + self.listeners[event_name].append(listener) + + def emit(self, event_name, *args, **kwargs): + """Fire the event without waiting for listeners to finish.""" + if event_name in self.listeners: + for listener in self.listeners[event_name]: + asyncio.create_task(listener(*args, **kwargs)) diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py index ee7c45b4..673a8431 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py @@ -1,7 +1,10 @@ -from dataclasses import dataclass, fields import builtins -from typing import TypeVar, Type +import json + +from typing import TypeVar, Type, List from ten import TenEnv +from dataclasses import dataclass, fields + T = TypeVar('T', bound='BaseConfig') @@ -28,7 +31,6 @@ def _init(obj, ten_env: TenEnv): # if not ten_env.is_property_exist(field.name): # continue try: - ten_env.log_info(f"init field.name: {field.name}") match field.type: case builtins.str: val = ten_env.get_property_string(field.name) @@ -44,6 +46,7 @@ def _init(obj, ten_env: TenEnv): val = ten_env.get_property_float(field.name) setattr(obj, field.name, val) case _: - pass + val = ten_env.get_property_to_json(field.name) + setattr(obj, field.name, json.loads(val)) except Exception as e: - ten_env.log_error(f"Error: {e}") + pass diff --git a/demo/src/app/api/agents/start/graph.tsx b/demo/src/app/api/agents/start/graph.tsx index a2f55dc2..fba485e1 100644 --- a/demo/src/app/api/agents/start/graph.tsx +++ b/demo/src/app/api/agents/start/graph.tsx @@ -127,7 +127,7 @@ export const getGraphProperties = ( "voice": voiceNameMap[language]["openai"][voiceType], "language": language, ...localizationOptions, - "system_message": prompt, + "prompt": prompt, "greeting": greeting, } } @@ -138,7 +138,7 @@ export const getGraphProperties = ( "voice": voiceNameMap[language]["openai"][voiceType], "language": language, ...localizationOptions, - "system_message": prompt, + "prompt": prompt, "greeting": greeting, }, "agora_rtc": { From 635bba276f2812d7f60e2f7c610b658319dbfc81 Mon Sep 17 00:00:00 2001 From: TomasLiu Date: Mon, 25 Nov 2024 18:37:44 +0800 Subject: [PATCH 2/2] reconnect opneai v2v --- .../extension/glue_python_async/extension.py | 14 +++ .../extension/openai_v2v_python/extension.py | 96 +++++++++++++++---- 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/agents/ten_packages/extension/glue_python_async/extension.py b/agents/ten_packages/extension/glue_python_async/extension.py index 02ecc4eb..56a13b60 100644 --- a/agents/ten_packages/extension/glue_python_async/extension.py +++ b/agents/ten_packages/extension/glue_python_async/extension.py @@ -134,6 +134,20 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: ten_env.log_info(f"config: {self.config}") self.memory = ChatMemory(self.config.max_history) + + if self.config.enable_storage: + result = await ten_env.send_cmd(Cmd.create("retrieve")) + if result.get_status_code() == StatusCode.OK: + try: + history = json.loads(result.get_property_string("response")) + for i in history: + self.memory.put(i) + ten_env.log_info(f"on retrieve context {history}") + except Exception as e: + ten_env.log_error("Failed to handle retrieve result {e}") + else: + ten_env.log_warn("Failed to retrieve content") + self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended) self.ten_env = ten_env diff --git a/agents/ten_packages/extension/openai_v2v_python/extension.py b/agents/ten_packages/extension/openai_v2v_python/extension.py index a36c32f1..61912b23 100644 --- a/agents/ten_packages/extension/openai_v2v_python/extension.py +++ b/agents/ten_packages/extension/openai_v2v_python/extension.py @@ -8,6 +8,8 @@ import asyncio import base64 import traceback +import time +import numpy as np from datetime import datetime from typing import Iterable @@ -23,7 +25,7 @@ from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL from ten_ai_base.llm import AsyncLLMBaseExtension from dataclasses import dataclass, field -from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage +from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam from .realtime.connection import RealtimeApiConnection from .realtime.struct import * @@ -72,7 +74,6 @@ class OpenAIRealtimeExtension(AsyncLLMBaseExtension): connected: bool = False buffer: bytearray = b'' memory: ChatMemory = None - retrieved: list = [] total_usage: LLMUsage = LLMUsage() users_count = 0 @@ -88,6 +89,7 @@ class OpenAIRealtimeExtension(AsyncLLMBaseExtension): buff: bytearray = b'' transcript: str = "" ctx: dict = {} + input_end = time.time() async def on_init(self, ten_env: AsyncTenEnv) -> None: await super().on_init(ten_env) @@ -108,21 +110,22 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: try: self.memory = ChatMemory(self.config.max_history) - self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired) - self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended) if self.config.enable_storage: result = await ten_env.send_cmd(Cmd.create("retrieve")) if result.get_status_code() == StatusCode.OK: try: history = json.loads(result.get_property_string("response")) - self.retrieved = history + for i in history: + self.memory.put(i) ten_env.log_info(f"on retrieve context {history}") except Exception as e: ten_env.log_error("Failed to handle retrieve result {e}") else: ten_env.log_warn("Failed to retrieve content") - + + self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired) + self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended) self.conn = RealtimeApiConnection( ten_env=ten_env, @@ -155,6 +158,8 @@ async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> self._dump_audio_if_need(frame_buf, Role.User) await self._on_audio(frame_buf) + if not self.config.server_vad: + self.input_end = time.time() except Exception as e: traceback.print_exc() self.ten_env.log_error(f"OpenAIV2VExtension on audio frame failed {e}") @@ -197,7 +202,9 @@ def get_time_ms() -> int: return current_time.microsecond // 1000 try: + start_time = time.time() await self.conn.connect() + self.connect_times.append(time.time() - start_time) item_id = "" # For truncate response_id = "" content_index = 0 @@ -216,13 +223,14 @@ def get_time_ms() -> int: self.session = message.session await self._update_session() - if self.retrieved: - for r in self.retrieved: - if r["role"] == MessageRole.User: - await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}]))) - elif r["role"] == MessageRole.Assistant: - await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}]))) - self.ten_env.log_info(f"after append retrieve: {len(self.retrieved)}") + history = self.memory.get() + for h in history: + if h["role"] == "user": + await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": h["content"]}]))) + elif h["role"] == "assistant": + await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": h["content"]}]))) + self.ten_env.log_info(f"Finish send history {history}") + self.memory.clear() if not self.connected: self.connected = True @@ -242,10 +250,12 @@ def get_time_ms() -> int: case ResponseDone(): id = message.response.id status = message.response.status - self.ten_env.log_info( - f"On response done {id} {status}") if id == response_id: response_id = "" + self.ten_env.log_info( + f"On response done {id} {status} {message.response.usage}") + if message.response.usage: + await self._update_usage(message.response.usage) case ResponseAudioTranscriptDelta(): self.ten_env.log_info( f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") @@ -261,6 +271,9 @@ def get_time_ms() -> int: self.ten_env.log_warn( f"On flushed text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") continue + if item_id != message.item_id: + item_id = message.item_id + self.first_token_times.append(time.time() - self.input_end) self._send_transcript(message.delta, Role.Assistant, False) case ResponseAudioTranscriptDone(): self.ten_env.log_info( @@ -279,6 +292,7 @@ def get_time_ms() -> int: self.ten_env.log_warn( f"On flushed text done {message.response_id}") continue + self.completion_times.append(time.time() - self.input_end) self.transcript = "" self._send_transcript("", Role.Assistant, True) case ResponseOutputItemDone(): @@ -291,9 +305,13 @@ def get_time_ms() -> int: self.ten_env.log_warn( f"On flushed audio delta {message.response_id} {message.item_id} {message.content_index}") continue - item_id = message.item_id + if item_id != message.item_id: + item_id = message.item_id + self.first_token_times.append(time.time() - self.input_end) content_index = message.content_index self._on_audio_delta(message.delta) + case ResponseAudioDone(): + self.completion_times.append(time.time() - self.input_end) case InputAudioBufferSpeechStarted(): self.ten_env.log_info( f"On server listening, in response {response_id}, last item {item_id}") @@ -313,6 +331,8 @@ def get_time_ms() -> int: flushed.add(response_id) item_id = "" case InputAudioBufferSpeechStopped(): + # Only for server vad + self.input_end = time.time() relative_start_ms = get_time_ms() - message.audio_end_ms self.ten_env.log_info( f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}") @@ -340,6 +360,17 @@ def get_time_ms() -> int: # clear so that new session can be triggered self.connected = False self.remote_stream_id = 0 + + if not self.stopped: + await self.conn.close() + await asyncio.sleep(0.5) + self.ten_env.log_info("Reconnect") + + self.conn = RealtimeApiConnection( + ten_env=self.ten_env, + base_uri=self.config.base_uri, path=self.config.path, api_key=self.config.api_key, model=self.config.model, vendor=self.config.vendor) + + self.loop.create_task(self._loop()) async def _on_memory_expired(self, message: dict) -> None: self.ten_env.log_info(f"Memory expired: {message}") @@ -590,3 +621,36 @@ async def _flush(self) -> None: await self.ten_env.send_cmd(c) except: self.ten_env.log_error(f"Error flush") + + async def _update_usage(self, usage: dict) -> None: + self.total_usage.completion_tokens += usage.get("output_tokens") + self.total_usage.prompt_tokens += usage.get("input_tokens") + self.total_usage.total_tokens += usage.get("total_tokens") + if not self.total_usage.completion_tokens_details: + self.total_usage.completion_tokens_details = LLMCompletionTokensDetails() + if not self.total_usage.prompt_tokens_details: + self.total_usage.prompt_tokens_details = LLMPromptTokensDetails() + + if usage.get("output_token_details"): + self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage["output_token_details"].get("text_tokens") + self.total_usage.completion_tokens_details.audio_tokens += usage["output_token_details"].get("audio_tokens") + + if usage.get("input_token_details:"): + self.total_usage.prompt_tokens_details.audio_tokens += usage["input_token_details"].get("audio_tokens") + self.total_usage.prompt_tokens_details.cached_tokens += usage["input_token_details"].get("cached_tokens") + self.total_usage.prompt_tokens_details.text_tokens += usage["input_token_details"].get("text_tokens") + + self.ten_env.log_info(f"total usage: {self.total_usage}") + + data = Data.create("llm_stat") + data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump())) + if self.connect_times and self.completion_times and self.first_token_times: + data.set_property_from_json("latency", json.dumps({ + "connection_latency_95": np.percentile(self.connect_times, 95), + "completion_latency_95": np.percentile(self.completion_times, 95), + "first_token_latency_95": np.percentile(self.first_token_times, 95), + "connection_latency_99": np.percentile(self.connect_times, 99), + "completion_latency_99": np.percentile(self.completion_times, 99), + "first_token_latency_99": np.percentile(self.first_token_times, 99) + })) + self.ten_env.send_data(data)