diff --git a/agents/ten_packages/extension/glue_python_async/extension.py b/agents/ten_packages/extension/glue_python_async/extension.py
index 02ecc4eb..56a13b60 100644
--- a/agents/ten_packages/extension/glue_python_async/extension.py
+++ b/agents/ten_packages/extension/glue_python_async/extension.py
@@ -134,6 +134,20 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
         ten_env.log_info(f"config: {self.config}")
 
         self.memory = ChatMemory(self.config.max_history)
+
+        if self.config.enable_storage:
+            result = await ten_env.send_cmd(Cmd.create("retrieve"))
+            if result.get_status_code() == StatusCode.OK:
+                try:
+                    history = json.loads(result.get_property_string("response"))
+                    for i in history:
+                        self.memory.put(i)
+                    ten_env.log_info(f"on retrieve context {history}")
+                except Exception as e:
+                    ten_env.log_error(f"Failed to handle retrieve result {e}")
+            else:
+                ten_env.log_warn("Failed to retrieve content")
+
         self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)
 
         self.ten_env = ten_env
diff --git a/agents/ten_packages/extension/openai_v2v_python/extension.py b/agents/ten_packages/extension/openai_v2v_python/extension.py
index a36c32f1..61912b23 100644
--- a/agents/ten_packages/extension/openai_v2v_python/extension.py
+++ b/agents/ten_packages/extension/openai_v2v_python/extension.py
@@ -8,6 +8,8 @@
 import asyncio
 import base64
 import traceback
+import time
+import numpy as np
 
 from datetime import datetime
 from typing import Iterable
@@ -23,7 +25,7 @@
 from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL
 from ten_ai_base.llm import AsyncLLMBaseExtension
 from dataclasses import dataclass, field
-from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage
+from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails
 from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam
 from .realtime.connection import RealtimeApiConnection
 from .realtime.struct import *
@@ -72,7 +74,6 @@ class OpenAIRealtimeExtension(AsyncLLMBaseExtension):
     connected: bool = False
     buffer: bytearray = b''
     memory: ChatMemory = None
-    retrieved: list = []
     total_usage: LLMUsage = LLMUsage()
     users_count = 0
 
@@ -88,6 +89,7 @@ class OpenAIRealtimeExtension(AsyncLLMBaseExtension):
     buff: bytearray = b''
     transcript: str = ""
     ctx: dict = {}
+    input_end = time.time()
 
     async def on_init(self, ten_env: AsyncTenEnv) -> None:
         await super().on_init(ten_env)
@@ -108,21 +110,22 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
         try:
             self.memory = ChatMemory(self.config.max_history)
-            self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired)
-            self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)
 
             if self.config.enable_storage:
                 result = await ten_env.send_cmd(Cmd.create("retrieve"))
                 if result.get_status_code() == StatusCode.OK:
                     try:
                         history = json.loads(result.get_property_string("response"))
-                        self.retrieved = history
+                        for i in history:
+                            self.memory.put(i)
                         ten_env.log_info(f"on retrieve context {history}")
                     except Exception as e:
                         ten_env.log_error(f"Failed to handle retrieve result {e}")
                 else:
                     ten_env.log_warn("Failed to retrieve content")
-
+
+            self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired)
+            self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)
 
             self.conn = RealtimeApiConnection(
                 ten_env=ten_env,
@@ -155,6 +158,8 @@ async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None:
             self._dump_audio_if_need(frame_buf, Role.User)
 
             await self._on_audio(frame_buf)
+            if not self.config.server_vad:
+                self.input_end = time.time()
         except Exception as e:
             traceback.print_exc()
             self.ten_env.log_error(f"OpenAIV2VExtension on audio frame failed {e}")
@@ -197,7 +202,9 @@ def get_time_ms() -> int:
             return current_time.microsecond // 1000
 
         try:
+            start_time = time.time()
             await self.conn.connect()
+            self.connect_times.append(time.time() - start_time)
             item_id = ""  # For truncate
             response_id = ""
             content_index = 0
@@ -216,13 +223,14 @@ def get_time_ms() -> int:
                                 self.session = message.session
                                 await self._update_session()
 
-                                if self.retrieved:
-                                    for r in self.retrieved:
-                                        if r["role"] == MessageRole.User:
-                                            await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}])))
-                                        elif r["role"] == MessageRole.Assistant:
-                                            await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}])))
-                                    self.ten_env.log_info(f"after append retrieve: {len(self.retrieved)}")
+                                history = self.memory.get()
+                                for h in history:
+                                    if h["role"] == "user":
+                                        await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": h["content"]}])))
+                                    elif h["role"] == "assistant":
+                                        await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": h["content"]}])))
+                                self.ten_env.log_info(f"Finish send history {history}")
+                                self.memory.clear()
 
                                 if not self.connected:
                                     self.connected = True
@@ -242,10 +250,12 @@ def get_time_ms() -> int:
                             case ResponseDone():
                                 id = message.response.id
                                 status = message.response.status
-                                self.ten_env.log_info(
-                                    f"On response done {id} {status}")
                                 if id == response_id:
                                     response_id = ""
+                                self.ten_env.log_info(
+                                    f"On response done {id} {status} {message.response.usage}")
+                                if message.response.usage:
+                                    await self._update_usage(message.response.usage)
                             case ResponseAudioTranscriptDelta():
                                 self.ten_env.log_info(
                                     f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
@@ -261,6 +271,9 @@ def get_time_ms() -> int:
                                     self.ten_env.log_warn(
                                         f"On flushed text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
                                     continue
+                                if item_id != message.item_id:
+                                    item_id = message.item_id
+                                    self.first_token_times.append(time.time() - self.input_end)
                                 self._send_transcript(message.delta, Role.Assistant, False)
                             case ResponseAudioTranscriptDone():
                                 self.ten_env.log_info(
                                     f"On response transcript done {message.response_id} {message.output_index} {message.content_index} {message.transcript}")
@@ -279,6 +292,7 @@ def get_time_ms() -> int:
                                     self.ten_env.log_warn(
                                         f"On flushed text done {message.response_id}")
                                     continue
+                                self.completion_times.append(time.time() - self.input_end)
                                 self.transcript = ""
                                 self._send_transcript("", Role.Assistant, True)
                             case ResponseOutputItemDone():
@@ -291,9 +305,13 @@ def get_time_ms() -> int:
                                     self.ten_env.log_warn(
                                         f"On flushed audio delta {message.response_id} {message.item_id} {message.content_index}")
                                     continue
-                                item_id = message.item_id
+                                if item_id != message.item_id:
+                                    item_id = message.item_id
+                                    self.first_token_times.append(time.time() - self.input_end)
                                 content_index = message.content_index
                                 self._on_audio_delta(message.delta)
+                            case ResponseAudioDone():
+                                self.completion_times.append(time.time() - self.input_end)
                             case InputAudioBufferSpeechStarted():
                                 self.ten_env.log_info(
                                     f"On server listening, in response {response_id}, last item {item_id}")
@@ -313,6 +331,8 @@ def get_time_ms() -> int:
                                     flushed.add(response_id)
                                 item_id = ""
                             case InputAudioBufferSpeechStopped():
+                                # Only for server vad
+                                self.input_end = time.time()
                                 relative_start_ms = get_time_ms() - message.audio_end_ms
                                 self.ten_env.log_info(
                                     f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}")
@@ -340,6 +360,17 @@ def get_time_ms() -> int:
                 # clear so that new session can be triggered
                 self.connected = False
                 self.remote_stream_id = 0
+
+                if not self.stopped:
+                    await self.conn.close()
+                    await asyncio.sleep(0.5)
+                    self.ten_env.log_info("Reconnect")
+
+                    self.conn = RealtimeApiConnection(
+                        ten_env=self.ten_env,
+                        base_uri=self.config.base_uri, path=self.config.path, api_key=self.config.api_key, model=self.config.model, vendor=self.config.vendor)
+
+                    self.loop.create_task(self._loop())
 
     async def _on_memory_expired(self, message: dict) -> None:
         self.ten_env.log_info(f"Memory expired: {message}")
@@ -590,3 +621,36 @@ async def _flush(self) -> None:
             await self.ten_env.send_cmd(c)
         except:
             self.ten_env.log_error(f"Error flush")
+
+    async def _update_usage(self, usage: dict) -> None:
+        self.total_usage.completion_tokens += usage.get("output_tokens", 0)
+        self.total_usage.prompt_tokens += usage.get("input_tokens", 0)
+        self.total_usage.total_tokens += usage.get("total_tokens", 0)
+        if not self.total_usage.completion_tokens_details:
+            self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
+        if not self.total_usage.prompt_tokens_details:
+            self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()
+
+        if usage.get("output_token_details"):
+            self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage["output_token_details"].get("text_tokens", 0)
+            self.total_usage.completion_tokens_details.audio_tokens += usage["output_token_details"].get("audio_tokens", 0)
+
+        if usage.get("input_token_details"):
+            self.total_usage.prompt_tokens_details.audio_tokens += usage["input_token_details"].get("audio_tokens", 0)
+            self.total_usage.prompt_tokens_details.cached_tokens += usage["input_token_details"].get("cached_tokens", 0)
+            self.total_usage.prompt_tokens_details.text_tokens += usage["input_token_details"].get("text_tokens", 0)
+
+        self.ten_env.log_info(f"total usage: {self.total_usage}")
+
+        data = Data.create("llm_stat")
+        data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
+        if self.connect_times and self.completion_times and self.first_token_times:
+            data.set_property_from_json("latency", json.dumps({
+                "connection_latency_95": np.percentile(self.connect_times, 95),
+                "completion_latency_95": np.percentile(self.completion_times, 95),
+                "first_token_latency_95": np.percentile(self.first_token_times, 95),
+                "connection_latency_99": np.percentile(self.connect_times, 99),
+                "completion_latency_99": np.percentile(self.completion_times, 99),
+                "first_token_latency_99": np.percentile(self.first_token_times, 99)
+            }))
+        self.ten_env.send_data(data)
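
Reviewer note: the standalone sketch below isolates the latency bookkeeping this patch adds, so the percentile math in _update_usage can be checked outside the extension. The list names mirror the diff (connect_times, completion_times, first_token_times, which the patch assumes are initialized elsewhere in the class); the sample timings are fabricated, and the "llm_stat" payload is printed rather than sent as a TEN Data object.

# Sketch of the p95/p99 stats emitted on the "llm_stat" data message.
import json
import numpy as np

# In the extension these accumulate per turn:
#   connect_times     - duration of conn.connect()
#   completion_times  - input_end -> ResponseAudioDone / transcript done
#   first_token_times - input_end -> first delta seen for a new item_id
connect_times = [0.31, 0.28, 0.44]      # fabricated sample values (seconds)
completion_times = [1.90, 2.40, 2.10]
first_token_times = [0.62, 0.75, 0.58]

latency = {
    "connection_latency_95": np.percentile(connect_times, 95),
    "completion_latency_95": np.percentile(completion_times, 95),
    "first_token_latency_95": np.percentile(first_token_times, 95),
    "connection_latency_99": np.percentile(connect_times, 99),
    "completion_latency_99": np.percentile(completion_times, 99),
    "first_token_latency_99": np.percentile(first_token_times, 99),
}
# np.percentile returns np.float64, a float subclass, so json.dumps accepts it
# directly -- the same reason the dumps call in _update_usage works.
print(json.dumps(latency, indent=2))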