Skip to content

Commit

Permalink
fix: recreate session when closed
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyoucao577 authored Oct 9, 2024
1 parent 1b805ee commit 396c006
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions agents/ten_packages/extension/openai_v2v_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@
TOOL_REGISTER_PROPERTY_DESCRIPTON = "description"
TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters"


class Role(str, Enum):
User = "user"
Assistant = "assistant"


class OpenAIV2VExtension(Extension):
def __init__(self, name: str):
super().__init__(name)
Expand Down Expand Up @@ -101,7 +103,7 @@ def start_event_loop(loop):
self.thread = threading.Thread(
target=start_event_loop, args=(self.loop,))
self.thread.start()

self._register_local_tools()

asyncio.run_coroutine_threadsafe(self._init_connection(), self.loop)
Expand All @@ -123,7 +125,7 @@ def on_stop(self, ten_env: TenEnv) -> None:
def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
try:
stream_id = audio_frame.get_property_int("stream_id")
#logger.debug(f"on_audio_frame {stream_id}")
# logger.debug(f"on_audio_frame {stream_id}")
if self.channel_name == "":
self.channel_name = audio_frame.get_property_string("channel")

Expand Down Expand Up @@ -206,41 +208,41 @@ def get_time_ms() -> int:
# await self.conn.send_request(update_conversation)
case ItemInputAudioTranscriptionCompleted():
logger.info(
f"On request transript {message.transcript}")
f"On request transcript {message.transcript}")
self._send_transcript(
ten_env, message.transcript, Role.User, True)
case ItemInputAudioTranscriptionFailed():
logger.warning(
f"On request transript failed {message.item_id} {message.error}")
f"On request transcript failed {message.item_id} {message.error}")
case ItemCreated():
logger.info(f"On item created {message.item}")
case ResponseCreated():
response_id = message.response.id
logger.info(
f"On response created {response_id}")
case ResponseDone():
id = message.response.id
id = message.response.id
status = message.response.status
logger.info(
f"On response done {id} {status}")
if id == response_id:
response_id = ""
case ResponseAudioTranscriptDelta():
logger.info(
f"On response transript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
if message.response_id in flushed:
logger.warning(
f"On flushed transript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
f"On flushed transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}")
continue
self.transcript += message.delta
self._send_transcript(
ten_env, self.transcript, Role.Assistant, False)
case ResponseAudioTranscriptDone():
logger.info(
f"On response transript done {message.output_index} {message.content_index} {message.transcript}")
f"On response transcript done {message.output_index} {message.content_index} {message.transcript}")
if message.response_id in flushed:
logger.warning(
f"On flushed transript done {message.response_id}")
f"On flushed transcript done {message.response_id}")
continue
self.transcript = ""
self._send_transcript(
Expand Down Expand Up @@ -285,7 +287,8 @@ def get_time_ms() -> int:
name = message.name
arguments = message.arguments
logger.info(f"need to call func {name}")
await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output) # TODO rebuild this into async, or it will block the thread
# TODO rebuild this into async, or it will block the thread
await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output)
case ErrorMessage():
logger.error(
f"Error message received: {message.error}")
Expand All @@ -299,6 +302,10 @@ def get_time_ms() -> int:
except:
logger.exception(f"Failed to handle loop")

# clear so that new session can be triggered
self.connected = False
self.remote_stream_id = 0

async def _on_audio(self, buff: bytearray):
self.out_audio_buff += buff
# Buffer audio
Expand Down Expand Up @@ -371,7 +378,7 @@ def _fetch_properties(self, ten_env: TenEnv):
except Exception as err:
logger.info(
f"GetProperty optional {PROPERTY_LANGUAGE} error: {err}")

try:
greeting = ten_env.get_property_string(PROPERTY_GREETING)
if greeting:
Expand Down Expand Up @@ -407,7 +414,7 @@ def _update_session(self) -> SessionUpdate:
model="whisper-1"),
tool_choice="auto",
tools=self.registry.get_tools()
))
))

'''
def _update_conversation(self) -> UpdateConversationConfig:
Expand Down Expand Up @@ -481,7 +488,8 @@ def _register_local_tools(self) -> None:
def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd):
try:
name = cmd.get_property_string(TOOL_REGISTER_PROPERTY_NAME)
description = cmd.get_property_string(TOOL_REGISTER_PROPERTY_DESCRIPTON)
description = cmd.get_property_string(
TOOL_REGISTER_PROPERTY_DESCRIPTON)
pstr = cmd.get_property_string(TOOL_REGISTER_PROPERTY_PARAMETERS)
parameters = json.loads(pstr)
p = partial(self._remote_tool_call, ten_env)
Expand All @@ -494,16 +502,16 @@ def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd):
except:
logger.exception(f"Failed to register")

async def _remote_tool_call(self, ten_env: TenEnv, name:str, args: str, callback: Awaitable):
async def _remote_tool_call(self, ten_env: TenEnv, name: str, args: str, callback: Awaitable):
logger.info(f"_remote_tool_call {name} {args}")
c = Cmd.create(CMD_TOOL_CALL)
c.set_property_string(CMD_PROPERTY_NAME, name)
c.set_property_string(CMD_PROPERTY_ARGS, args)
ten_env.send_cmd(c, lambda ten, result: asyncio.run_coroutine_threadsafe(
callback(result.get_property_string("response")), self.loop))
callback(result.get_property_string("response")), self.loop))
logger.info(f"_remote_tool_call finish {name} {args}")
async def _on_tool_output(self, tool_call_id:str, result: str):

async def _on_tool_output(self, tool_call_id: str, result: str):
logger.info(f"_on_tool_output {tool_call_id} {result}")
try:
tool_response = ItemCreate(
Expand All @@ -526,4 +534,4 @@ def _greeting_text(self) -> str:
text = "こんにちは"
elif self.config.language == "ko-KR":
text = "안녕하세요"
return text
return text

0 comments on commit 396c006

Please sign in to comment.