
Commit

reconnect openai v2v
TomasBack2Future committed Nov 25, 2024
1 parent c19ec72 commit 635bba2
Showing 2 changed files with 94 additions and 16 deletions.
14 changes: 14 additions & 0 deletions agents/ten_packages/extension/glue_python_async/extension.py
@@ -134,6 +134,20 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_info(f"config: {self.config}")

self.memory = ChatMemory(self.config.max_history)
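# Restore persisted chat history before wiring up memory events. The
# "retrieve" cmd is presumably answered by a storage extension elsewhere
# in the graph, returning a JSON list of prior messages.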

if self.config.enable_storage:
result = await ten_env.send_cmd(Cmd.create("retrieve"))
if result.get_status_code() == StatusCode.OK:
try:
history = json.loads(result.get_property_string("response"))
for i in history:
self.memory.put(i)
ten_env.log_info(f"on retrieve context {history}")
except Exception as e:
ten_env.log_error("Failed to handle retrieve result {e}")
else:
ten_env.log_warn("Failed to retrieve content")

self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

self.ten_env = ten_env
96 changes: 80 additions & 16 deletions agents/ten_packages/extension/openai_v2v_python/extension.py
@@ -8,6 +8,8 @@
import asyncio
import base64
import traceback
import time
import numpy as np
from datetime import datetime
from typing import Iterable

@@ -23,7 +25,7 @@
from ten_ai_base.const import CMD_PROPERTY_RESULT, CMD_TOOL_CALL
from ten_ai_base.llm import AsyncLLMBaseExtension
from dataclasses import dataclass, field
from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage
from ten_ai_base import BaseConfig, ChatMemory, EVENT_MEMORY_EXPIRED, EVENT_MEMORY_APPENDED, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails
from ten_ai_base.types import LLMToolMetadata, LLMToolResult, LLMChatCompletionContentPartParam
from .realtime.connection import RealtimeApiConnection
from .realtime.struct import *
@@ -72,7 +74,6 @@ class OpenAIRealtimeExtension(AsyncLLMBaseExtension):
connected: bool = False
buffer: bytearray = b''
memory: ChatMemory = None
retrieved: list = []
total_usage: LLMUsage = LLMUsage()
users_count = 0

@@ -88,6 +89,7 @@ class OpenAIRealtimeExtension(AsyncLLMBaseExtension):
buff: bytearray = b''
transcript: str = ""
ctx: dict = {}
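# Wall-clock timestamp of when the user's input ended; the latency
# bookkeeping below measures first-token and completion times against it.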
input_end = time.time()

async def on_init(self, ten_env: AsyncTenEnv) -> None:
await super().on_init(ten_env)
@@ -108,21 +110,22 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:

try:
self.memory = ChatMemory(self.config.max_history)
self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired)
self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)
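# Same pattern as in glue_python_async: restore persisted history (if any)
# into local memory; it is replayed to the server once the session starts.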

if self.config.enable_storage:
result = await ten_env.send_cmd(Cmd.create("retrieve"))
if result.get_status_code() == StatusCode.OK:
try:
history = json.loads(result.get_property_string("response"))
self.retrieved = history
for i in history:
self.memory.put(i)
ten_env.log_info(f"on retrieve context {history}")
except Exception as e:
ten_env.log_error("Failed to handle retrieve result {e}")
else:
ten_env.log_warn("Failed to retrieve content")


self.memory.on(EVENT_MEMORY_EXPIRED, self._on_memory_expired)
self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

self.conn = RealtimeApiConnection(
ten_env=ten_env,
@@ -155,6 +158,8 @@ async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) ->
self._dump_audio_if_need(frame_buf, Role.User)

await self._on_audio(frame_buf)
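# Without server-side VAD, treat the latest audio frame as the end of
# user input for latency measurement.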
if not self.config.server_vad:
self.input_end = time.time()
except Exception as e:
traceback.print_exc()
self.ten_env.log_error(f"OpenAIV2VExtension on audio frame failed {e}")
@@ -197,7 +202,9 @@ def get_time_ms() -> int:
return current_time.microsecond // 1000

try:
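# Measure websocket connect time; percentiles are reported in _update_usage.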
start_time = time.time()
await self.conn.connect()
self.connect_times.append(time.time() - start_time)
item_id = "" # For truncate
response_id = ""
content_index = 0
@@ -216,13 +223,14 @@ def get_time_ms() -> int:
self.session = message.session
await self._update_session()

if self.retrieved:
for r in self.retrieved:
if r["role"] == MessageRole.User:
await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}])))
elif r["role"] == MessageRole.Assistant:
await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": r["input"]}])))
self.ten_env.log_info(f"after append retrieve: {len(self.retrieved)}")
history = self.memory.get()
for h in history:
if h["role"] == "user":
await self.conn.send_request(ItemCreate(item=UserMessageItemParam(content=[{"type": ContentType.InputText, "text": h["content"]}])))
elif h["role"] == "assistant":
await self.conn.send_request(ItemCreate(item=AssistantMessageItemParam(content=[{"type": ContentType.InputText, "text": h["content"]}])))
self.ten_env.log_info(f"Finish send history {history}")
self.memory.clear()

if not self.connected:
self.connected = True
@@ -242,10 +250,12 @@ def get_time_ms() -> int:
case ResponseDone():
id = message.response.id
status = message.response.status
self.ten_env.log_info(
f"On response done {id} {status}")
if id == response_id:
response_id = ""
self.ten_env.log_info(
f"On response done {id} {status} {message.response.usage}")
if message.response.usage:
await self._update_usage(message.response.usage)
case ResponseAudioTranscriptDelta():
self.ten_env.log_info(
f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
@@ -261,6 +271,9 @@ def get_time_ms() -> int:
self.ten_env.log_warn(
f"On flushed text delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
continue
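# First transcript delta for a new item marks time-to-first-token.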
if item_id != message.item_id:
item_id = message.item_id
self.first_token_times.append(time.time() - self.input_end)
self._send_transcript(message.delta, Role.Assistant, False)
case ResponseAudioTranscriptDone():
self.ten_env.log_info(
@@ -279,6 +292,7 @@ def get_time_ms() -> int:
self.ten_env.log_warn(
f"On flushed text done {message.response_id}")
continue
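# Full transcript received: record completion latency relative to input end.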
self.completion_times.append(time.time() - self.input_end)
self.transcript = ""
self._send_transcript("", Role.Assistant, True)
case ResponseOutputItemDone():
@@ -291,9 +305,13 @@ def get_time_ms() -> int:
self.ten_env.log_warn(
f"On flushed audio delta {message.response_id} {message.item_id} {message.content_index}")
continue
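# First audio delta for a new item also counts toward time-to-first-token.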
item_id = message.item_id
if item_id != message.item_id:
item_id = message.item_id
self.first_token_times.append(time.time() - self.input_end)
content_index = message.content_index
self._on_audio_delta(message.delta)
case ResponseAudioDone():
self.completion_times.append(time.time() - self.input_end)
case InputAudioBufferSpeechStarted():
self.ten_env.log_info(
f"On server listening, in response {response_id}, last item {item_id}")
@@ -313,6 +331,8 @@ def get_time_ms() -> int:
flushed.add(response_id)
item_id = ""
case InputAudioBufferSpeechStopped():
# Only emitted when server-side VAD is enabled; marks end of user speech
self.input_end = time.time()
relative_start_ms = get_time_ms() - message.audio_end_ms
self.ten_env.log_info(
f"On server stop listening, {message.audio_end_ms}, relative {relative_start_ms}")
@@ -340,6 +360,17 @@ def get_time_ms() -> int:
# clear so that new session can be triggered
self.connected = False
self.remote_stream_id = 0
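# The event loop above exits when the connection drops or the request
# fails. Unless the extension is stopping, close the stale connection,
# back off briefly, and reconnect with a new task running this loop
# (the reconnect behavior this commit adds).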

if not self.stopped:
await self.conn.close()
await asyncio.sleep(0.5)
self.ten_env.log_info("Reconnect")

self.conn = RealtimeApiConnection(
ten_env=self.ten_env,
base_uri=self.config.base_uri, path=self.config.path, api_key=self.config.api_key, model=self.config.model, vendor=self.config.vendor)

self.loop.create_task(self._loop())

async def _on_memory_expired(self, message: dict) -> None:
self.ten_env.log_info(f"Memory expired: {message}")
@@ -590,3 +621,36 @@ async def _flush(self) -> None:
await self.ten_env.send_cmd(c)
except Exception:
self.ten_env.log_error("Failed to flush")

async def _update_usage(self, usage: dict) -> None:
self.total_usage.completion_tokens += usage.get("output_tokens", 0)
self.total_usage.prompt_tokens += usage.get("input_tokens", 0)
self.total_usage.total_tokens += usage.get("total_tokens", 0)
if not self.total_usage.completion_tokens_details:
self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
if not self.total_usage.prompt_tokens_details:
self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()

if usage.get("output_token_details"):
self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage["output_token_details"].get("text_tokens", 0)
self.total_usage.completion_tokens_details.audio_tokens += usage["output_token_details"].get("audio_tokens", 0)

if usage.get("input_token_details:"):
self.total_usage.prompt_tokens_details.audio_tokens += usage["input_token_details"].get("audio_tokens")
self.total_usage.prompt_tokens_details.cached_tokens += usage["input_token_details"].get("cached_tokens")
self.total_usage.prompt_tokens_details.text_tokens += usage["input_token_details"].get("text_tokens")

self.ten_env.log_info(f"total usage: {self.total_usage}")

data = Data.create("llm_stat")
data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
if self.connect_times and self.completion_times and self.first_token_times:
data.set_property_from_json("latency", json.dumps({
"connection_latency_95": np.percentile(self.connect_times, 95),
"completion_latency_95": np.percentile(self.completion_times, 95),
"first_token_latency_95": np.percentile(self.first_token_times, 95),
"connection_latency_99": np.percentile(self.connect_times, 99),
"completion_latency_99": np.percentile(self.completion_times, 99),
"first_token_latency_99": np.percentile(self.first_token_times, 99)
}))
self.ten_env.send_data(data)
