Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: greeting on user joined #313

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions agents/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,26 @@
}
]
}
],
"cmd": [
{
"name": "on_user_joined",
wangyoucao577 marked this conversation as resolved.
Show resolved Hide resolved
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
},
{
"name": "on_user_left",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
}
]
},
{
Expand Down Expand Up @@ -1495,6 +1515,26 @@
}
]
}
],
"cmd": [
{
"name": "on_user_joined",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
},
{
"name": "on_user_left",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
}
]
},
{
Expand Down
122 changes: 79 additions & 43 deletions agents/ten_packages/extension/openai_chatgpt_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from .log import logger

CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_IN_ON_USER_LEFT = "on_user_left"
CMD_OUT_FLUSH = "flush"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
Expand All @@ -51,6 +53,7 @@
TASK_TYPE_CHAT_COMPLETION = "chat_completion"
TASK_TYPE_CHAT_COMPLETION_WITH_VISION = "chat_completion_with_vision"


class OpenAIChatGPTExtension(Extension):
memory = []
max_memory_length = 10
Expand Down Expand Up @@ -86,6 +89,7 @@ def on_start(self, ten_env: TenEnv) -> None:
logger.info("on_start")

self.loop = asyncio.new_event_loop()

def start_loop():
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
Expand All @@ -97,50 +101,56 @@ def start_loop():
openai_chatgpt_config = OpenAIChatGPTConfig.default_config()

# Mandatory properties
openai_chatgpt_config.base_url = get_property_string(ten_env, PROPERTY_BASE_URL) or openai_chatgpt_config.base_url
openai_chatgpt_config.api_key = get_property_string(ten_env, PROPERTY_API_KEY)
openai_chatgpt_config.base_url = get_property_string(
ten_env, PROPERTY_BASE_URL) or openai_chatgpt_config.base_url
openai_chatgpt_config.api_key = get_property_string(
ten_env, PROPERTY_API_KEY)
if not openai_chatgpt_config.api_key:
logger.info(f"API key is missing, exiting on_start")
return

# Optional properties
openai_chatgpt_config.model = get_property_string(ten_env, PROPERTY_MODEL) or openai_chatgpt_config.model
openai_chatgpt_config.prompt = get_property_string(ten_env, PROPERTY_PROMPT) or openai_chatgpt_config.prompt
openai_chatgpt_config.frequency_penalty = get_property_float(ten_env, PROPERTY_FREQUENCY_PENALTY) or openai_chatgpt_config.frequency_penalty
openai_chatgpt_config.presence_penalty = get_property_float(ten_env, PROPERTY_PRESENCE_PENALTY) or openai_chatgpt_config.presence_penalty
openai_chatgpt_config.temperature = get_property_float(ten_env, PROPERTY_TEMPERATURE) or openai_chatgpt_config.temperature
openai_chatgpt_config.top_p = get_property_float(ten_env, PROPERTY_TOP_P) or openai_chatgpt_config.top_p
openai_chatgpt_config.max_tokens = get_property_int(ten_env, PROPERTY_MAX_TOKENS) or openai_chatgpt_config.max_tokens
openai_chatgpt_config.proxy_url = get_property_string(ten_env, PROPERTY_PROXY_URL) or openai_chatgpt_config.proxy_url
openai_chatgpt_config.model = get_property_string(
ten_env, PROPERTY_MODEL) or openai_chatgpt_config.model
openai_chatgpt_config.prompt = get_property_string(
ten_env, PROPERTY_PROMPT) or openai_chatgpt_config.prompt
openai_chatgpt_config.frequency_penalty = get_property_float(
ten_env, PROPERTY_FREQUENCY_PENALTY) or openai_chatgpt_config.frequency_penalty
openai_chatgpt_config.presence_penalty = get_property_float(
ten_env, PROPERTY_PRESENCE_PENALTY) or openai_chatgpt_config.presence_penalty
openai_chatgpt_config.temperature = get_property_float(
ten_env, PROPERTY_TEMPERATURE) or openai_chatgpt_config.temperature
openai_chatgpt_config.top_p = get_property_float(
ten_env, PROPERTY_TOP_P) or openai_chatgpt_config.top_p
openai_chatgpt_config.max_tokens = get_property_int(
ten_env, PROPERTY_MAX_TOKENS) or openai_chatgpt_config.max_tokens
openai_chatgpt_config.proxy_url = get_property_string(
ten_env, PROPERTY_PROXY_URL) or openai_chatgpt_config.proxy_url

# Properties that don't affect openai_chatgpt_config
greeting = get_property_string(ten_env, PROPERTY_GREETING)
self.greeting = get_property_string(ten_env, PROPERTY_GREETING)
self.enable_tools = get_property_bool(ten_env, PROPERTY_ENABLE_TOOLS)
self.max_memory_length = get_property_int(ten_env, PROPERTY_MAX_MEMORY_LENGTH)
checking_vision_text_items_str = get_property_string(ten_env, PROPERTY_CHECKING_VISION_TEXT_ITEMS)
self.max_memory_length = get_property_int(
ten_env, PROPERTY_MAX_MEMORY_LENGTH)
checking_vision_text_items_str = get_property_string(
ten_env, PROPERTY_CHECKING_VISION_TEXT_ITEMS)
if checking_vision_text_items_str:
try:
self.checking_vision_text_items = json.loads(checking_vision_text_items_str)
self.checking_vision_text_items = json.loads(
checking_vision_text_items_str)
except Exception as err:
logger.info(f"Error parsing {PROPERTY_CHECKING_VISION_TEXT_ITEMS}: {err}")
logger.info(
f"Error parsing {PROPERTY_CHECKING_VISION_TEXT_ITEMS}: {err}")
self.users_count = 0

# Create instance
try:
self.openai_chatgpt = OpenAIChatGPT(openai_chatgpt_config)
logger.info(f"initialized with max_tokens: {openai_chatgpt_config.max_tokens}, model: {openai_chatgpt_config.model}")
logger.info(
f"initialized with max_tokens: {openai_chatgpt_config.max_tokens}, model: {openai_chatgpt_config.model}")
except Exception as err:
logger.info(f"Failed to initialize OpenAIChatGPT: {err}")

# Send greeting if available
if greeting:
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
ten_env.send_data(output_data)
logger.info(f"Greeting [{greeting}] sent")
except Exception as err:
logger.info(f"Failed to send greeting [{greeting}]: {err}")
ten_env.on_start_done()

def on_stop(self, ten_env: TenEnv) -> None:
Expand All @@ -155,19 +165,38 @@ def on_deinit(self, ten_env: TenEnv) -> None:
ten_env.on_deinit_done()

def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
    """Handle an incoming command from the TEN graph.

    Supported commands:
      * CMD_IN_FLUSH ("flush"): cancel the in-flight chat task on the
        worker loop and forward a "flush" command downstream.
      * CMD_IN_ON_USER_JOINED ("on_user_joined"): increment the user
        count and send the configured greeting when the first user joins.
      * CMD_IN_ON_USER_LEFT ("on_user_left"): decrement the user count.

    Always returns a CmdResult to the caller: OK for known commands,
    ERROR with detail "unknown cmd" otherwise.
    """
    # NOTE(review): this span comes from a diff rendering; the json log
    # below and the name log that follows may not both exist in the
    # merged file — confirm against the repository.
    logger.info(f"on_cmd json: {cmd.to_json()}")

    cmd_name = cmd.get_name()
    logger.info(f"on_cmd name: {cmd_name}")

    if cmd_name == CMD_IN_FLUSH:
        # Cancel the current chat task from this (non-loop) thread, then
        # propagate the flush downstream so later stages drop stale output.
        asyncio.run_coroutine_threadsafe(self._flush_queue(), self.loop)
        ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH), None)
        logger.info("on_cmd sent flush")
        status_code, detail = StatusCode.OK, "success"
    elif cmd_name == CMD_IN_ON_USER_JOINED:
        self.users_count += 1
        # Send greeting when first user joined
        if self.greeting and self.users_count == 1:
            try:
                output_data = Data.create("text_data")
                output_data.set_property_string(
                    DATA_OUT_TEXT_DATA_PROPERTY_TEXT, self.greeting)
                output_data.set_property_bool(
                    DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
                ten_env.send_data(output_data)
                logger.info(f"Greeting [{self.greeting}] sent")
            except Exception as err:
                # Best effort: a failed greeting must not fail the command.
                logger.info(
                    f"Failed to send greeting [{self.greeting}]: {err}")

        status_code, detail = StatusCode.OK, "success"
    elif cmd_name == CMD_IN_ON_USER_LEFT:
        # NOTE(review): users_count can go negative if a leave arrives
        # without a matching join — confirm upstream guarantees pairing.
        self.users_count -= 1
        status_code, detail = StatusCode.OK, "success"
    else:
        logger.info(f"on_cmd unknown cmd: {cmd_name}")
        status_code, detail = StatusCode.ERROR, "unknown cmd"

    cmd_result = CmdResult.create(status_code)
    cmd_result.set_property_string("detail", detail)
    ten_env.return_result(cmd_result, cmd)
Expand All @@ -187,7 +216,8 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
logger.info(f"OnData input text: [{input_text}]")

# Start an asynchronous task for handling chat completion
asyncio.run_coroutine_threadsafe(self.queue.put([TASK_TYPE_CHAT_COMPLETION, input_text]), self.loop)
asyncio.run_coroutine_threadsafe(self.queue.put(
[TASK_TYPE_CHAT_COMPLETION, input_text]), self.loop)

def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
# TODO: process pcm frame
Expand All @@ -207,7 +237,8 @@ async def _process_queue(self, ten_env: TenEnv):
[task_type, message] = await self.queue.get()
try:
# Create a new task for the new message
self.current_task = asyncio.create_task(self._run_chatflow(ten_env, task_type, message, self.memory))
self.current_task = asyncio.create_task(
self._run_chatflow(ten_env, task_type, message, self.memory))
await self.current_task # Wait for the current task to finish or be cancelled
except asyncio.CancelledError:
logger.info(f"Task cancelled: {message}")
Expand All @@ -222,7 +253,7 @@ async def _flush_queue(self):
logger.info("Cancelling the current task during flush.")
self.current_task.cancel()

async def _run_chatflow(self, ten_env: TenEnv, task_type:str, input_text: str, memory):
async def _run_chatflow(self, ten_env: TenEnv, task_type: str, input_text: str, memory):
"""Run the chatflow asynchronously."""
memory_cache = []
try:
Expand All @@ -233,14 +264,17 @@ async def _run_chatflow(self, ten_env: TenEnv, task_type:str, input_text: str, m
# Prepare the message and tools based on the task type
if task_type == TASK_TYPE_CHAT_COMPLETION:
message = {"role": "user", "content": input_text}
memory_cache = memory_cache + [message, {"role": "assistant", "content": ""}]
memory_cache = memory_cache + \
[message, {"role": "assistant", "content": ""}]
tools = self.available_tools if self.enable_tools else None
elif task_type == TASK_TYPE_CHAT_COMPLETION_WITH_VISION:
message = {"role": "user", "content": input_text}
memory_cache = memory_cache + [message, {"role": "assistant", "content": ""}]
memory_cache = memory_cache + \
[message, {"role": "assistant", "content": ""}]
tools = self.available_tools if self.enable_tools else None
if self.image_data is not None:
url = rgb2base64jpeg(self.image_data, self.image_width, self.image_height)
url = rgb2base64jpeg(
self.image_data, self.image_width, self.image_height)
message = {
"role": "user",
"content": [
Expand All @@ -250,7 +284,6 @@ async def _run_chatflow(self, ten_env: TenEnv, task_type:str, input_text: str, m
}
logger.info(f"msg with vision data: {message}")


self.sentence_fragment = ""

# Create an asyncio.Event to signal when content is finished
Expand All @@ -263,17 +296,18 @@ async def handle_tool_call(tool_call):
# Append the vision image to the last assistant message
await self.queue.put([TASK_TYPE_CHAT_COMPLETION_WITH_VISION, input_text], True)

async def handle_content_update(content:str):
async def handle_content_update(content: str):
# Append the content to the last assistant message
for item in reversed(memory_cache):
if item.get('role') == 'assistant':
item['content'] = item['content'] + content
break
sentences, self.sentence_fragment = parse_sentences(self.sentence_fragment, content)
sentences, self.sentence_fragment = parse_sentences(
self.sentence_fragment, content)
for s in sentences:
self._send_data(ten_env, s, False)

async def handle_content_finished(full_content:str):
async def handle_content_finished(full_content: str):
content_finished_event.set()

listener = AsyncEventEmitter()
Expand All @@ -289,22 +323,24 @@ async def handle_content_finished(full_content:str):
except asyncio.CancelledError:
logger.info(f"Task cancelled: {input_text}")
except Exception as e:
logger.error(f"Error in chat_completion: {traceback.format_exc()} for input text: {input_text}")
logger.error(
f"Error in chat_completion: {traceback.format_exc()} for input text: {input_text}")
finally:
self._send_data(ten_env, "", True)
# always append the memory
for m in memory_cache:
self._append_memory(m)

def _append_memory(self, message: str):
    """Append a chat message to the bounded conversation memory.

    Evicts the oldest entries first so the memory holds at most
    ``max_memory_length`` entries after the append.  The previous ``>``
    comparison only popped once the bound was already exceeded, letting
    the list settle at ``max_memory_length + 1`` items.

    NOTE(review): callers pass message dicts ({"role": ..., "content": ...});
    the ``str`` annotation is kept unchanged for interface compatibility
    but looks inaccurate — confirm and fix file-wide separately.
    """
    # ``while`` (plus the non-empty guard) also recovers when the list is
    # already over the limit, e.g. if max_memory_length is ever lowered,
    # and never pops from an empty list when the limit is 0.
    while self.memory and len(self.memory) >= self.max_memory_length:
        self.memory.pop(0)  # drop the oldest message first
    self.memory.append(message)
self.memory.append(message)

def _send_data(self, ten_env: TenEnv, sentence: str, end_of_segment: bool):
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_string(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment
)
Expand All @@ -315,4 +351,4 @@ def _send_data(self, ten_env: TenEnv, sentence: str, end_of_segment: bool):
except Exception as err:
logger.info(
f"send sentence [{sentence}] failed, err: {err}"
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@
"cmd_in": [
{
"name": "flush"
},
{
"name": "on_user_joined"
},
{
"name": "on_user_left"
}
],
"cmd_out": [
Expand Down