Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ChatMemory and BaseConfig #375

Merged
merged 7 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agents/bin/start
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ set -e

cd "$(dirname "${BASH_SOURCE[0]}")/.."

export PYTHONPATH=/app/agents/ten_packages/system/ten_ai_base/interface:$PYTHONPATH
export PYTHONPATH=$(pwd)/ten_packages/system/ten_ai_base/interface:$PYTHONPATH
export LD_LIBRARY_PATH=$(pwd)/ten_packages/system/agora_rtc_sdk/lib:$(pwd)/ten_packages/system/azure_speech_sdk/lib

exec bin/worker "$@"
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,25 @@
#
from ten import AsyncTenEnv
from ten_ai_base import (
AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata
AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata, BaseConfig
)
from dataclasses import dataclass


@dataclass
class DefaultAsyncLLMConfig(BaseConfig):
model: str = ""
# TODO: add extra config fields here


class DefaultAsyncLLMExtension(AsyncLLMBaseExtension):
async def on_start(self, ten_env: AsyncTenEnv) -> None:
await super().on_start(ten_env)

# initialize configuration
self.config = DefaultAsyncLLMConfig.create(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")

"""Implement this method to construct and start your resources."""
ten_env.log_debug("TODO: on_start")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
]
},
"api": {
"property": {},
"property": {
"model": {
"type": "string"
}
},
"cmd_in": [
{
"name": "tool_register",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@
TenEnv,
AsyncTenEnv,
)
from ten_ai_base import AsyncLLMToolBaseExtension, LLMToolMetadata, LLMToolResult
from ten_ai_base import (
AsyncLLMToolBaseExtension, LLMToolMetadata, LLMToolResult, BaseConfig
)
from dataclasses import dataclass


@dataclass
class DefaultAsyncLLMToolConfig(BaseConfig):
# TODO: add extra config fields here
pass


class DefaultAsyncLLMToolExtension(AsyncLLMToolBaseExtension):
async def on_start(self, ten_env: AsyncTenEnv) -> None:
await super().on_start(ten_env)

# initialize configuration
self.config = DefaultAsyncLLMToolConfig.create(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")

"""Implement this method to construct and start your resources."""
ten_env.log_debug("TODO: on_start")

Expand All @@ -25,6 +38,7 @@ async def on_stop(self, ten_env: AsyncTenEnv) -> None:

def get_tool_metadata(self, ten_env: TenEnv) -> list[LLMToolMetadata]:
ten_env.log_debug("TODO: get_tool_metadata")
return []

async def run_tool(self, ten_env: AsyncTenEnv, name: str, args: dict) -> LLMToolResult:
ten_env.log_debug(f"TODO: run_tool {name} {args}")
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from .types import LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata, LLMToolResult
from .llm import AsyncLLMBaseExtension
from .llm_tool import AsyncLLMToolBaseExtension
from .chat_memory import ChatMemory
from .helper import AsyncQueue, AsyncEventEmitter
from .config import BaseConfig

# Specify what should be imported when a user imports * from the
# ten_ai_base package.
Expand All @@ -16,5 +19,9 @@
"LLMCallCompletionArgs",
"LLMDataCompletionArgs",
"AsyncLLMBaseExtension",
"AsyncLLMToolBaseExtension"
"AsyncLLMToolBaseExtension",
"ChatMemory",
"AsyncQueue",
"AsyncEventEmitter",
"BaseConfig"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import threading


class ChatMemory:
def __init__(self, max_history_length):
self.max_history_length = max_history_length
self.history = []
self.mutex = threading.Lock() # TODO: no need lock for asyncio

def put(self, message):
with self.mutex:
self.history.append(message)

while True:
history_count = len(self.history)
if history_count > 0 and history_count > self.max_history_length:
self.history.pop(0)
continue
if history_count > 0 and self.history[0]["role"] == "assistant":
# we cannot have an assistant message at the start of the chat history
# if after removal of the first, we have an assistant message,
# we need to remove the assistant message too
self.history.pop(0)
continue
break

def get(self):
with self.mutex:
return self.history

def count(self):
with self.mutex:
return len(self.history)

def clear(self):
with self.mutex:
self.history = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from dataclasses import dataclass, fields
import builtins
from typing import TypeVar, Type
from ten import TenEnv

T = TypeVar('T', bound='BaseConfig')


@dataclass
class BaseConfig:
"""
Base class for implementing configuration.
Extra configuration fields can be added in inherited class.
"""

@classmethod
def create(cls: Type[T], ten_env: TenEnv) -> T:
c = cls()
c._init(ten_env)
return c

def _init(obj, ten_env: TenEnv):
"""
Get property from ten_env to initialize the dataclass config.
"""
for field in fields(obj):
# TODO: 'is_property_exist' has a bug that can not be used in async extension currently, use it instead of try .. except once fixed
# if not ten_env.is_property_exist(field.name):
# continue
try:
match field.type:
case builtins.str:
val = ten_env.get_property_string(field.name)
if val:
setattr(obj, field.name, val)
case builtins.int:
val = ten_env.get_property_int(field.name)
setattr(obj, field.name, val)
case builtins.bool:
val = ten_env.get_property_bool(field.name)
setattr(obj, field.name, val)
case builtins.float:
val = ten_env.get_property_float(field.name)
setattr(obj, field.name, val)
case _:
pass
except Exception as e:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ def __init__(self, name: str):
self.hit_default_cmd = False

async def on_init(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_init")
await super().on_init(ten_env)

async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_start")
await super().on_start(ten_env)

self.loop = asyncio.get_event_loop()
self.loop.create_task(self._process_queue(ten_env))

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_stop")
await super().on_stop(ten_env)

async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_deinit")
await super().on_deinit(ten_env)

async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

class AsyncLLMToolBaseExtension(AsyncExtension, ABC):
async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_start")
tools = self.get_tool_metadata()
await super().on_start(ten_env)

tools = self.get_tool_metadata(ten_env)
for tool in tools:
ten_env.log_info(f"tool: {tool}")
c: Cmd = Cmd.create(CMD_TOOL_REGISTER)
Expand All @@ -27,10 +28,7 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_info(f"tool registered, {tool}")

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_stop")

async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_deinit")
await super().on_stop(ten_env)

async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()
Expand Down