Skip to content

Commit

Permalink
fix: greeting on user joined
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyoucao577 authored Oct 9, 2024
1 parent 1b805ee commit eb874f9
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 43 deletions.
22 changes: 22 additions & 0 deletions agents/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,17 @@
}
]
}
],
"cmd": [
{
"name": "on_user_joined",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
}
]
},
{
Expand Down Expand Up @@ -1495,6 +1506,17 @@
}
]
}
],
"cmd": [
{
"name": "on_user_joined",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
}
]
},
{
Expand Down
116 changes: 73 additions & 43 deletions agents/ten_packages/extension/openai_chatgpt_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .log import logger

CMD_IN_FLUSH = "flush"
CMD_IN_ON_USER_JOINED = "on_user_joined"
CMD_OUT_FLUSH = "flush"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
Expand All @@ -51,6 +52,7 @@
TASK_TYPE_CHAT_COMPLETION = "chat_completion"
TASK_TYPE_CHAT_COMPLETION_WITH_VISION = "chat_completion_with_vision"


class OpenAIChatGPTExtension(Extension):
memory = []
max_memory_length = 10
Expand Down Expand Up @@ -86,6 +88,7 @@ def on_start(self, ten_env: TenEnv) -> None:
logger.info("on_start")

self.loop = asyncio.new_event_loop()

def start_loop():
asyncio.set_event_loop(self.loop)
self.loop.run_forever()
Expand All @@ -97,50 +100,55 @@ def start_loop():
openai_chatgpt_config = OpenAIChatGPTConfig.default_config()

# Mandatory properties
openai_chatgpt_config.base_url = get_property_string(ten_env, PROPERTY_BASE_URL) or openai_chatgpt_config.base_url
openai_chatgpt_config.api_key = get_property_string(ten_env, PROPERTY_API_KEY)
openai_chatgpt_config.base_url = get_property_string(
ten_env, PROPERTY_BASE_URL) or openai_chatgpt_config.base_url
openai_chatgpt_config.api_key = get_property_string(
ten_env, PROPERTY_API_KEY)
if not openai_chatgpt_config.api_key:
logger.info(f"API key is missing, exiting on_start")
return

# Optional properties
openai_chatgpt_config.model = get_property_string(ten_env, PROPERTY_MODEL) or openai_chatgpt_config.model
openai_chatgpt_config.prompt = get_property_string(ten_env, PROPERTY_PROMPT) or openai_chatgpt_config.prompt
openai_chatgpt_config.frequency_penalty = get_property_float(ten_env, PROPERTY_FREQUENCY_PENALTY) or openai_chatgpt_config.frequency_penalty
openai_chatgpt_config.presence_penalty = get_property_float(ten_env, PROPERTY_PRESENCE_PENALTY) or openai_chatgpt_config.presence_penalty
openai_chatgpt_config.temperature = get_property_float(ten_env, PROPERTY_TEMPERATURE) or openai_chatgpt_config.temperature
openai_chatgpt_config.top_p = get_property_float(ten_env, PROPERTY_TOP_P) or openai_chatgpt_config.top_p
openai_chatgpt_config.max_tokens = get_property_int(ten_env, PROPERTY_MAX_TOKENS) or openai_chatgpt_config.max_tokens
openai_chatgpt_config.proxy_url = get_property_string(ten_env, PROPERTY_PROXY_URL) or openai_chatgpt_config.proxy_url
openai_chatgpt_config.model = get_property_string(
ten_env, PROPERTY_MODEL) or openai_chatgpt_config.model
openai_chatgpt_config.prompt = get_property_string(
ten_env, PROPERTY_PROMPT) or openai_chatgpt_config.prompt
openai_chatgpt_config.frequency_penalty = get_property_float(
ten_env, PROPERTY_FREQUENCY_PENALTY) or openai_chatgpt_config.frequency_penalty
openai_chatgpt_config.presence_penalty = get_property_float(
ten_env, PROPERTY_PRESENCE_PENALTY) or openai_chatgpt_config.presence_penalty
openai_chatgpt_config.temperature = get_property_float(
ten_env, PROPERTY_TEMPERATURE) or openai_chatgpt_config.temperature
openai_chatgpt_config.top_p = get_property_float(
ten_env, PROPERTY_TOP_P) or openai_chatgpt_config.top_p
openai_chatgpt_config.max_tokens = get_property_int(
ten_env, PROPERTY_MAX_TOKENS) or openai_chatgpt_config.max_tokens
openai_chatgpt_config.proxy_url = get_property_string(
ten_env, PROPERTY_PROXY_URL) or openai_chatgpt_config.proxy_url

# Properties that don't affect openai_chatgpt_config
greeting = get_property_string(ten_env, PROPERTY_GREETING)
self.greeting = get_property_string(ten_env, PROPERTY_GREETING)
self.enable_tools = get_property_bool(ten_env, PROPERTY_ENABLE_TOOLS)
self.max_memory_length = get_property_int(ten_env, PROPERTY_MAX_MEMORY_LENGTH)
checking_vision_text_items_str = get_property_string(ten_env, PROPERTY_CHECKING_VISION_TEXT_ITEMS)
self.max_memory_length = get_property_int(
ten_env, PROPERTY_MAX_MEMORY_LENGTH)
checking_vision_text_items_str = get_property_string(
ten_env, PROPERTY_CHECKING_VISION_TEXT_ITEMS)
if checking_vision_text_items_str:
try:
self.checking_vision_text_items = json.loads(checking_vision_text_items_str)
self.checking_vision_text_items = json.loads(
checking_vision_text_items_str)
except Exception as err:
logger.info(f"Error parsing {PROPERTY_CHECKING_VISION_TEXT_ITEMS}: {err}")
logger.info(
f"Error parsing {PROPERTY_CHECKING_VISION_TEXT_ITEMS}: {err}")

# Create instance
try:
self.openai_chatgpt = OpenAIChatGPT(openai_chatgpt_config)
logger.info(f"initialized with max_tokens: {openai_chatgpt_config.max_tokens}, model: {openai_chatgpt_config.model}")
logger.info(
f"initialized with max_tokens: {openai_chatgpt_config.max_tokens}, model: {openai_chatgpt_config.model}")
except Exception as err:
logger.info(f"Failed to initialize OpenAIChatGPT: {err}")

# Send greeting if available
if greeting:
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
ten_env.send_data(output_data)
logger.info(f"Greeting [{greeting}] sent")
except Exception as err:
logger.info(f"Failed to send greeting [{greeting}]: {err}")
ten_env.on_start_done()

def on_stop(self, ten_env: TenEnv) -> None:
Expand All @@ -155,19 +163,34 @@ def on_deinit(self, ten_env: TenEnv) -> None:
ten_env.on_deinit_done()

def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
    """Handle an incoming command.

    Supported commands:
      * ``flush`` — cancel the in-flight chat task, drain the queue, and
        forward a ``flush`` command downstream.
      * ``on_user_joined`` — send the configured greeting text (if any) as a
        ``text_data`` output so the user is welcomed when they join.

    Any other command name is answered with an ERROR result. A CmdResult with
    a ``detail`` string is always returned to the caller.
    """
    logger.info(f"on_cmd json: {cmd.to_json()}")

    cmd_name = cmd.get_name()
    logger.info(f"on_cmd name: {cmd_name}")

    if cmd_name == CMD_IN_FLUSH:
        # Cancel current generation on the extension's event loop, then
        # propagate the flush downstream (fire-and-forget: no result handler).
        asyncio.run_coroutine_threadsafe(self._flush_queue(), self.loop)
        ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH), None)
        logger.info("on_cmd sent flush")
        status_code, detail = StatusCode.OK, "success"
    elif cmd_name == CMD_IN_ON_USER_JOINED:
        # on_start() returns early (before assigning self.greeting) when the
        # API key is missing, so the attribute may not exist yet; getattr
        # avoids an AttributeError in that case.
        greeting = getattr(self, "greeting", None)

        # Send greeting if available
        if greeting:
            try:
                output_data = Data.create("text_data")
                output_data.set_property_string(
                    DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
                output_data.set_property_bool(
                    DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
                ten_env.send_data(output_data)
                logger.info(f"Greeting [{greeting}] sent")
            except Exception as err:
                logger.info(
                    f"Failed to send greeting [{greeting}]: {err}")

        # Greeting delivery is best-effort; the command itself still succeeds.
        status_code, detail = StatusCode.OK, "success"
    else:
        logger.info(f"on_cmd unknown cmd: {cmd_name}")
        status_code, detail = StatusCode.ERROR, "unknown cmd"

    cmd_result = CmdResult.create(status_code)
    cmd_result.set_property_string("detail", detail)
    ten_env.return_result(cmd_result, cmd)
Expand All @@ -187,7 +210,8 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
logger.info(f"OnData input text: [{input_text}]")

# Start an asynchronous task for handling chat completion
asyncio.run_coroutine_threadsafe(self.queue.put([TASK_TYPE_CHAT_COMPLETION, input_text]), self.loop)
asyncio.run_coroutine_threadsafe(self.queue.put(
[TASK_TYPE_CHAT_COMPLETION, input_text]), self.loop)

def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
# TODO: process pcm frame
Expand All @@ -207,7 +231,8 @@ async def _process_queue(self, ten_env: TenEnv):
[task_type, message] = await self.queue.get()
try:
# Create a new task for the new message
self.current_task = asyncio.create_task(self._run_chatflow(ten_env, task_type, message, self.memory))
self.current_task = asyncio.create_task(
self._run_chatflow(ten_env, task_type, message, self.memory))
await self.current_task # Wait for the current task to finish or be cancelled
except asyncio.CancelledError:
logger.info(f"Task cancelled: {message}")
Expand All @@ -222,7 +247,7 @@ async def _flush_queue(self):
logger.info("Cancelling the current task during flush.")
self.current_task.cancel()

async def _run_chatflow(self, ten_env: TenEnv, task_type:str, input_text: str, memory):
async def _run_chatflow(self, ten_env: TenEnv, task_type: str, input_text: str, memory):
"""Run the chatflow asynchronously."""
memory_cache = []
try:
Expand All @@ -233,14 +258,17 @@ async def _run_chatflow(self, ten_env: TenEnv, task_type:str, input_text: str, m
# Prepare the message and tools based on the task type
if task_type == TASK_TYPE_CHAT_COMPLETION:
message = {"role": "user", "content": input_text}
memory_cache = memory_cache + [message, {"role": "assistant", "content": ""}]
memory_cache = memory_cache + \
[message, {"role": "assistant", "content": ""}]
tools = self.available_tools if self.enable_tools else None
elif task_type == TASK_TYPE_CHAT_COMPLETION_WITH_VISION:
message = {"role": "user", "content": input_text}
memory_cache = memory_cache + [message, {"role": "assistant", "content": ""}]
memory_cache = memory_cache + \
[message, {"role": "assistant", "content": ""}]
tools = self.available_tools if self.enable_tools else None
if self.image_data is not None:
url = rgb2base64jpeg(self.image_data, self.image_width, self.image_height)
url = rgb2base64jpeg(
self.image_data, self.image_width, self.image_height)
message = {
"role": "user",
"content": [
Expand All @@ -250,7 +278,6 @@ async def _run_chatflow(self, ten_env: TenEnv, task_type:str, input_text: str, m
}
logger.info(f"msg with vision data: {message}")


self.sentence_fragment = ""

# Create an asyncio.Event to signal when content is finished
Expand All @@ -263,17 +290,18 @@ async def handle_tool_call(tool_call):
# Append the vision image to the last assistant message
await self.queue.put([TASK_TYPE_CHAT_COMPLETION_WITH_VISION, input_text], True)

async def handle_content_update(content:str):
async def handle_content_update(content: str):
# Append the content to the last assistant message
for item in reversed(memory_cache):
if item.get('role') == 'assistant':
item['content'] = item['content'] + content
break
sentences, self.sentence_fragment = parse_sentences(self.sentence_fragment, content)
sentences, self.sentence_fragment = parse_sentences(
self.sentence_fragment, content)
for s in sentences:
self._send_data(ten_env, s, False)

async def handle_content_finished(full_content:str):
async def handle_content_finished(full_content: str):
content_finished_event.set()

listener = AsyncEventEmitter()
Expand All @@ -289,22 +317,24 @@ async def handle_content_finished(full_content:str):
except asyncio.CancelledError:
logger.info(f"Task cancelled: {input_text}")
except Exception as e:
logger.error(f"Error in chat_completion: {traceback.format_exc()} for input text: {input_text}")
logger.error(
f"Error in chat_completion: {traceback.format_exc()} for input text: {input_text}")
finally:
self._send_data(ten_env, "", True)
# always append the memory
for m in memory_cache:
self._append_memory(m)

def _append_memory(self, message:str):
def _append_memory(self, message: str):
if len(self.memory) > self.max_memory_length:
self.memory.pop(0)
self.memory.append(message)

def _send_data(self, ten_env: TenEnv, sentence: str, end_of_segment: bool):
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_string(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment
)
Expand All @@ -315,4 +345,4 @@ def _send_data(self, ten_env: TenEnv, sentence: str, end_of_segment: bool):
except Exception as err:
logger.info(
f"send sentence [{sentence}] failed, err: {err}"
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
"cmd_in": [
{
"name": "flush"
},
{
"name": "on_user_joined"
}
],
"cmd_out": [
Expand Down

0 comments on commit eb874f9

Please sign in to comment.