Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact openai v2v #438

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agents/examples/demo/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down Expand Up @@ -848,7 +848,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down
10 changes: 5 additions & 5 deletions agents/examples/experimental/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10,
"max_history": 10,
"vendor": "azure",
"base_uri": "${env:AZURE_OPENAI_REALTIME_BASE_URI}",
"path": "/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview",
"system_message": ""
"prompt": ""
}
},
{
Expand Down Expand Up @@ -2444,7 +2444,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down Expand Up @@ -2566,7 +2566,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10
"max_history": 10
}
},
{
Expand Down Expand Up @@ -2724,7 +2724,7 @@
"language": "en-US",
"server_vad": true,
"dump": true,
"history": 10,
"max_history": 10,
"enable_storage": true
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
from fastapi import Depends, FastAPI, HTTPException, Request
import asyncio

# Enable Pydantic debug mode
from pydantic import BaseConfig

BaseConfig.debug = True

# Set up logging
logging.config.dictConfig({
"version": 1,
Expand Down
222 changes: 156 additions & 66 deletions agents/ten_packages/extension/glue_python_async/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import traceback
import aiohttp
import json
import time
import re

from datetime import datetime
import numpy as np
from typing import List, Any, AsyncGenerator
from dataclasses import dataclass
from dataclasses import dataclass, field
from pydantic import BaseModel

from ten import (
Expand All @@ -23,7 +25,7 @@
Data,
)

from ten_ai_base import BaseConfig, ChatMemory
from ten_ai_base import BaseConfig, ChatMemory, LLMUsage, LLMCompletionTokensDetails, LLMPromptTokensDetails, EVENT_MEMORY_APPENDED
from ten_ai_base.llm import AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata
from ten_ai_base.types import LLMChatCompletionUserMessageParam, LLMToolResult

Expand Down Expand Up @@ -84,27 +86,9 @@ class Choice(BaseModel):
index: int
finish_reason: str | None

class CompletionTokensDetails(BaseModel):
accepted_prediction_tokens: int = 0
audio_tokens: int = 0
reasoning_tokens: int = 0
rejected_prediction_tokens: int = 0

class PromptTokensDetails(BaseModel):
audio_tokens: int = 0
cached_tokens: int = 0

class Usage(BaseModel):
completion_tokens: int = 0
prompt_tokens: int = 0
total_tokens: int = 0

completion_tokens_details: CompletionTokensDetails | None = None
prompt_tokens_details: PromptTokensDetails | None = None

class ResponseChunk(BaseModel):
choices: List[Choice]
usage: Usage | None = None
usage: LLMUsage | None = None

@dataclass
class GlueConfig(BaseConfig):
Expand All @@ -113,17 +97,29 @@ class GlueConfig(BaseConfig):
prompt: str = ""
max_history: int = 10
greeting: str = ""
failure_info: str = ""
modalities: List[str] = field(default_factory=lambda: ["text"])
rtm_enabled: bool = True
ssml_enabled: bool = False
context_enabled: bool = False
extra_context: dict = field(default_factory=dict)
enable_storage: bool = False

class AsyncGlueExtension(AsyncLLMBaseExtension):
config : GlueConfig = None
sentence_fragment: str = ""
ten_env: AsyncTenEnv = None
loop: asyncio.AbstractEventLoop = None
stopped: bool = False
memory: ChatMemory = None
total_usage: Usage = Usage()
total_usage: LLMUsage = LLMUsage()
users_count = 0

completion_times = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for metrics maybe create a utility class?

we can use like

metrics.mark("connect_times")
## after some seconds
elapse = metrics.elapse("connect_times")

connect_times = []
first_token_times = []

remote_stream_id: int = 999 # TODO

async def on_init(self, ten_env: AsyncTenEnv) -> None:
await super().on_init(ten_env)
ten_env.log_debug("on_init")
Expand All @@ -139,6 +135,21 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:

self.memory = ChatMemory(self.config.max_history)

if self.config.enable_storage:
result = await ten_env.send_cmd(Cmd.create("retrieve"))
if result.get_status_code() == StatusCode.OK:
try:
history = json.loads(result.get_property_string("response"))
for i in history:
self.memory.put(i)
ten_env.log_info(f"on retrieve context {history}")
except Exception as e:
ten_env.log_error("Failed to handle retrieve result {e}")
else:
ten_env.log_warn("Failed to retrieve content")

self.memory.on(EVENT_MEMORY_APPENDED, self._on_memory_appended)

self.ten_env = ten_env

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
Expand Down Expand Up @@ -187,7 +198,21 @@ async def on_data_chat_completion(self, ten_env: AsyncTenEnv, **kargs: LLMDataCo
messages = []
if self.config.prompt:
messages.append({"role": "system", "content": self.config.prompt})
messages.extend(self.memory.get())

history = self.memory.get()
while history:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while len(history) > 0?

if history[0].get("role") == "tool":
history = history[1:]
continue
if history[0].get("role") == "assistant" and history[0].get("tool_calls"):
history = history[1:]
continue

# Skip the first tool role
plutoless marked this conversation as resolved.
Show resolved Hide resolved
break

messages.extend(history)

if not input:
ten_env.log_warn("No message in data")
else:
Expand Down Expand Up @@ -220,6 +245,10 @@ def tool_dict(tool: LLMToolMetadata):
json["function"]["parameters"]["required"].append(param.name)

return json

def trim_xml(input_string):
return re.sub(r'<[^>]+>', '', input_string).strip()

tools = []
for tool in self.available_tools:
tools.append(tool_dict(tool))
Expand All @@ -229,16 +258,25 @@ def tool_dict(tool: LLMToolMetadata):
calls = {}

sentences = []
start_time = time.time()
first_token_time = None
response = self._stream_chat(messages=messages, tools=tools)
async for message in response:
self.ten_env.log_info(f"content: {message}")
self.ten_env.log_debug(f"content: {message}")
# TODO: handle tool call
try:
c = ResponseChunk(**message)
if c.choices:
if c.choices[0].delta.content:
total_output += c.choices[0].delta.content
sentences, sentence_fragment = parse_sentences(sentence_fragment, c.choices[0].delta.content)
if first_token_time is None:
first_token_time = time.time()
self.first_token_times.append(first_token_time - start_time)

content = c.choices[0].delta.content
if self.config.ssml_enabled and content.startswith("<speak>"):
content = trim_xml(content)
total_output += content
sentences, sentence_fragment = parse_sentences(sentence_fragment, content)
for s in sentences:
await self._send_text(s)
if c.choices[0].delta.tool_calls:
Expand All @@ -252,10 +290,14 @@ def tool_dict(tool: LLMToolMetadata):
calls[call.index].function.arguments += call.function.arguments
if c.usage:
self.ten_env.log_info(f"usage: {c.usage}")
self._update_usage(c.usage)
await self._update_usage(c.usage)
except Exception as e:
self.ten_env.log_error(f"Failed to parse response: {message} {e}")
traceback.print_exc()
if sentence_fragment:
await self._send_text(sentence_fragment)
end_time = time.time()
self.completion_times.append(end_time - start_time)

if total_output:
self.memory.put({"role": "assistant", "content": total_output})
Expand Down Expand Up @@ -343,48 +385,67 @@ async def _send_text(self, text: str) -> None:
self.ten_env.send_data(data)

async def _stream_chat(self, messages: List[Any], tools: List[Any]) -> AsyncGenerator[dict, None]:
session = aiohttp.ClientSession()
try:
payload = {
"messages": messages,
"tools": tools,
"tools_choice": "auto" if tools else "none",
"model": "gpt-3.5-turbo",
"stream": True,
"stream_options": {"include_usage": True}
}
self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}")
headers = {
"Authorization": f"Bearer {self.config.token}",
"Content-Type": "application/json"
}

async with session.post(self.config.api_url, json=payload, headers=headers) as response:
if response.status != 200:
r = await response.json()
self.ten_env.log_error(f"Received unexpected status {r} from the server.")
return
async with aiohttp.ClientSession() as session:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe cache it to avoid overhead?

try:
payload = {
"messages": messages,
"tools": tools,
"tools_choice": "auto" if tools else "none",
"model": "gpt-3.5-turbo",
"stream": True,
"stream_options": {"include_usage": True},
"ssml_enabled": self.config.ssml_enabled
}
if self.config.context_enabled:
payload["context"] = {
**self.config.extra_context
}
self.ten_env.log_info(f"payload before sending: {json.dumps(payload)}")
headers = {
"Authorization": f"Bearer {self.config.token}",
"Content-Type": "application/json"
}

async for line in response.content:
if line:
l = line.decode('utf-8').strip()
if l.startswith("data:"):
content = l[5:].strip()
if content == "[DONE]":
break
self.ten_env.log_info(f"content: {content}")
yield json.loads(content)
except Exception as e:
self.ten_env.log_error(f"Failed to handle {e}")
finally:
await session.close()
session = None
start_time = time.time()
async with session.post(self.config.api_url, json=payload, headers=headers) as response:
if response.status != 200:
r = await response.json()
self.ten_env.log_error(f"Received unexpected status {r} from the server.")
if self.config.failure_info:
await self._send_text(self.config.failure_info)
return
end_time = time.time()
self.connect_times.append(end_time - start_time)

async for line in response.content:
if line:
l = line.decode('utf-8').strip()
if l.startswith("data:"):
content = l[5:].strip()
if content == "[DONE]":
break
self.ten_env.log_debug(f"content: {content}")
yield json.loads(content)
except Exception as e:
traceback.print_exc()
self.ten_env.log_error(f"Failed to handle {e}")
finally:
await session.close()
session = None

async def _update_usage(self, usage: LLMUsage) -> None:
if not self.config.rtm_enabled:
return

async def _update_usage(self, usage: Usage) -> None:
self.total_usage.completion_tokens += usage.completion_tokens
self.total_usage.prompt_tokens += usage.prompt_tokens
self.total_usage.total_tokens += usage.total_tokens

if self.total_usage.completion_tokens_details is None:
self.total_usage.completion_tokens_details = LLMCompletionTokensDetails()
if self.total_usage.prompt_tokens_details is None:
self.total_usage.prompt_tokens_details = LLMPromptTokensDetails()

if usage.completion_tokens_details:
self.total_usage.completion_tokens_details.accepted_prediction_tokens += usage.completion_tokens_details.accepted_prediction_tokens
self.total_usage.completion_tokens_details.audio_tokens += usage.completion_tokens_details.audio_tokens
Expand All @@ -395,4 +456,33 @@ async def _update_usage(self, usage: Usage) -> None:
self.total_usage.prompt_tokens_details.audio_tokens += usage.prompt_tokens_details.audio_tokens
self.total_usage.prompt_tokens_details.cached_tokens += usage.prompt_tokens_details.cached_tokens

self.ten_env.log_info(f"total usage: {self.total_usage}")
self.ten_env.log_info(f"total usage: {self.total_usage}")

data = Data.create("llm_stat")
data.set_property_from_json("usage", json.dumps(self.total_usage.model_dump()))
if self.connect_times and self.completion_times and self.first_token_times:
data.set_property_from_json("latency", json.dumps({
"connection_latency_95": np.percentile(self.connect_times, 95),
"completion_latency_95": np.percentile(self.completion_times, 95),
"first_token_latency_95": np.percentile(self.first_token_times, 95),
"connection_latency_99": np.percentile(self.connect_times, 99),
"completion_latency_99": np.percentile(self.completion_times, 99),
"first_token_latency_99": np.percentile(self.first_token_times, 99)
}))
self.ten_env.send_data(data)

async def _on_memory_appended(self, message: dict) -> None:
self.ten_env.log_info(f"Memory appended: {message}")
if not self.config.enable_storage:
return

role = message.get("role")
stream_id = self.remote_stream_id if role == "user" else 0
try:
d = Data.create("append")
d.set_property_string("text", message.get("content"))
d.set_property_string("role", role)
d.set_property_int("stream_id", stream_id)
self.ten_env.send_data(d)
except Exception as e:
self.ten_env.log_error(f"Error send append_context data {message} {e}")
Loading