Skip to content

Commit

Permalink
feat: add ChatMemory and BaseConfig (#375)
Browse files Browse the repository at this point in the history
* feat: add chat memory

* fix: path

* fix: lifecycle

* feat: add base config

* fix: missed config.py

* fix: syntax

* feat: simplify usage
  • Loading branch information
wangyoucao577 authored Nov 2, 2024
1 parent 928c24f commit 05d97ec
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 15 deletions.
2 changes: 1 addition & 1 deletion agents/bin/start
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ set -e

cd "$(dirname "${BASH_SOURCE[0]}")/.."

export PYTHONPATH=/app/agents/ten_packages/system/ten_ai_base/interface:$PYTHONPATH
export PYTHONPATH=$(pwd)/ten_packages/system/ten_ai_base/interface:$PYTHONPATH
export LD_LIBRARY_PATH=$(pwd)/ten_packages/system/agora_rtc_sdk/lib:$(pwd)/ten_packages/system/azure_speech_sdk/lib

exec bin/worker "$@"
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,25 @@
#
from ten import AsyncTenEnv
from ten_ai_base import (
AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata
AsyncLLMBaseExtension, LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata, BaseConfig
)
from dataclasses import dataclass


@dataclass
class DefaultAsyncLLMConfig(BaseConfig):
model: str = ""
# TODO: add extra config fields here


class DefaultAsyncLLMExtension(AsyncLLMBaseExtension):
async def on_start(self, ten_env: AsyncTenEnv) -> None:
await super().on_start(ten_env)

# initialize configuration
self.config = DefaultAsyncLLMConfig.create(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")

"""Implement this method to construct and start your resources."""
ten_env.log_debug("TODO: on_start")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
]
},
"api": {
"property": {},
"property": {
"model": {
"type": "string"
}
},
"cmd_in": [
{
"name": "tool_register",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@
TenEnv,
AsyncTenEnv,
)
from ten_ai_base import AsyncLLMToolBaseExtension, LLMToolMetadata, LLMToolResult
from ten_ai_base import (
AsyncLLMToolBaseExtension, LLMToolMetadata, LLMToolResult, BaseConfig
)
from dataclasses import dataclass


@dataclass
class DefaultAsyncLLMToolConfig(BaseConfig):
# TODO: add extra config fields here
pass


class DefaultAsyncLLMToolExtension(AsyncLLMToolBaseExtension):
async def on_start(self, ten_env: AsyncTenEnv) -> None:
await super().on_start(ten_env)

# initialize configuration
self.config = DefaultAsyncLLMToolConfig.create(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")

"""Implement this method to construct and start your resources."""
ten_env.log_debug("TODO: on_start")

Expand All @@ -25,6 +38,7 @@ async def on_stop(self, ten_env: AsyncTenEnv) -> None:

def get_tool_metadata(self, ten_env: TenEnv) -> list[LLMToolMetadata]:
ten_env.log_debug("TODO: get_tool_metadata")
return []

async def run_tool(self, ten_env: AsyncTenEnv, name: str, args: dict) -> LLMToolResult:
ten_env.log_debug(f"TODO: run_tool {name} {args}")
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from .types import LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata, LLMToolResult
from .llm import AsyncLLMBaseExtension
from .llm_tool import AsyncLLMToolBaseExtension
from .chat_memory import ChatMemory
from .helper import AsyncQueue, AsyncEventEmitter
from .config import BaseConfig

# Specify what should be imported when a user imports * from the
# ten_ai_base package.
Expand All @@ -16,5 +19,9 @@
"LLMCallCompletionArgs",
"LLMDataCompletionArgs",
"AsyncLLMBaseExtension",
"AsyncLLMToolBaseExtension"
"AsyncLLMToolBaseExtension",
"ChatMemory",
"AsyncQueue",
"AsyncEventEmitter",
"BaseConfig"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import threading


class ChatMemory:
def __init__(self, max_history_length):
self.max_history_length = max_history_length
self.history = []
self.mutex = threading.Lock() # TODO: no need lock for asyncio

def put(self, message):
with self.mutex:
self.history.append(message)

while True:
history_count = len(self.history)
if history_count > 0 and history_count > self.max_history_length:
self.history.pop(0)
continue
if history_count > 0 and self.history[0]["role"] == "assistant":
# we cannot have an assistant message at the start of the chat history
# if after removal of the first, we have an assistant message,
# we need to remove the assistant message too
self.history.pop(0)
continue
break

def get(self):
with self.mutex:
return self.history

def count(self):
with self.mutex:
return len(self.history)

def clear(self):
with self.mutex:
self.history = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from dataclasses import dataclass, fields
import builtins
from typing import TypeVar, Type
from ten import TenEnv

T = TypeVar('T', bound='BaseConfig')


@dataclass
class BaseConfig:
"""
Base class for implementing configuration.
Extra configuration fields can be added in inherited class.
"""

@classmethod
def create(cls: Type[T], ten_env: TenEnv) -> T:
c = cls()
c._init(ten_env)
return c

def _init(obj, ten_env: TenEnv):
"""
Get property from ten_env to initialize the dataclass config.
"""
for field in fields(obj):
# TODO: 'is_property_exist' has a bug that can not be used in async extension currently, use it instead of try .. except once fixed
# if not ten_env.is_property_exist(field.name):
# continue
try:
match field.type:
case builtins.str:
val = ten_env.get_property_string(field.name)
if val:
setattr(obj, field.name, val)
case builtins.int:
val = ten_env.get_property_int(field.name)
setattr(obj, field.name, val)
case builtins.bool:
val = ten_env.get_property_bool(field.name)
setattr(obj, field.name, val)
case builtins.float:
val = ten_env.get_property_float(field.name)
setattr(obj, field.name, val)
case _:
pass
except Exception as e:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ def __init__(self, name: str):
self.hit_default_cmd = False

async def on_init(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_init")
await super().on_init(ten_env)

async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_start")
await super().on_start(ten_env)

self.loop = asyncio.get_event_loop()
self.loop.create_task(self._process_queue(ten_env))

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_stop")
await super().on_stop(ten_env)

async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_deinit")
await super().on_deinit(ten_env)

async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

class AsyncLLMToolBaseExtension(AsyncExtension, ABC):
async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_start")
tools = self.get_tool_metadata()
await super().on_start(ten_env)

tools = self.get_tool_metadata(ten_env)
for tool in tools:
ten_env.log_info(f"tool: {tool}")
c: Cmd = Cmd.create(CMD_TOOL_REGISTER)
Expand All @@ -27,10 +28,7 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_info(f"tool registered, {tool}")

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_stop")

async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_debug("on_deinit")
await super().on_stop(ten_env)

async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()
Expand Down

0 comments on commit 05d97ec

Please sign in to comment.