diff --git a/agents/ten_packages/extension/openai_v2v_python/extension.py b/agents/ten_packages/extension/openai_v2v_python/extension.py index 1457498a..e4050438 100644 --- a/agents/ten_packages/extension/openai_v2v_python/extension.py +++ b/agents/ten_packages/extension/openai_v2v_python/extension.py @@ -55,10 +55,12 @@ TOOL_REGISTER_PROPERTY_DESCRIPTON = "description" TOOL_REGISTER_PROPERTY_PARAMETERS = "parameters" + class Role(str, Enum): User = "user" Assistant = "assistant" + class OpenAIV2VExtension(Extension): def __init__(self, name: str): super().__init__(name) @@ -101,7 +103,7 @@ def start_event_loop(loop): self.thread = threading.Thread( target=start_event_loop, args=(self.loop,)) self.thread.start() - + self._register_local_tools() asyncio.run_coroutine_threadsafe(self._init_connection(), self.loop) @@ -123,7 +125,7 @@ def on_stop(self, ten_env: TenEnv) -> None: def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: try: stream_id = audio_frame.get_property_int("stream_id") - #logger.debug(f"on_audio_frame {stream_id}") + # logger.debug(f"on_audio_frame {stream_id}") if self.channel_name == "": self.channel_name = audio_frame.get_property_string("channel") @@ -206,12 +208,12 @@ def get_time_ms() -> int: # await self.conn.send_request(update_conversation) case ItemInputAudioTranscriptionCompleted(): logger.info( - f"On request transript {message.transcript}") + f"On request transcript {message.transcript}") self._send_transcript( ten_env, message.transcript, Role.User, True) case ItemInputAudioTranscriptionFailed(): logger.warning( - f"On request transript failed {message.item_id} {message.error}") + f"On request transcript failed {message.item_id} {message.error}") case ItemCreated(): logger.info(f"On item created {message.item}") case ResponseCreated(): @@ -219,7 +221,7 @@ def get_time_ms() -> int: logger.info( f"On response created {response_id}") case ResponseDone(): - id = message.response.id + id = message.response.id status = message.response.status logger.info( f"On response done {id} {status}") @@ -227,20 +229,20 @@ def get_time_ms() -> int: response_id = "" case ResponseAudioTranscriptDelta(): logger.info( - f"On response transript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") + f"On response transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") if message.response_id in flushed: logger.warning( - f"On flushed transript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") + f"On flushed transcript delta {message.response_id} {message.output_index} {message.content_index} {message.delta}") continue self.transcript += message.delta self._send_transcript( ten_env, self.transcript, Role.Assistant, False) case ResponseAudioTranscriptDone(): logger.info( - f"On response transript done {message.output_index} {message.content_index} {message.transcript}") + f"On response transcript done {message.output_index} {message.content_index} {message.transcript}") if message.response_id in flushed: logger.warning( - f"On flushed transript done {message.response_id}") + f"On flushed transcript done {message.response_id}") continue self.transcript = "" self._send_transcript( @@ -285,7 +287,8 @@ def get_time_ms() -> int: name = message.name arguments = message.arguments logger.info(f"need to call func {name}") - await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output) # TODO rebuild this into async, or it will block the thread + # TODO rebuild this into async, or it will block the thread + await self.registry.on_func_call(tool_call_id, name, arguments, self._on_tool_output) case ErrorMessage(): logger.error( f"Error message received: {message.error}") @@ -299,6 +302,10 @@ def get_time_ms() -> int: except: logger.exception(f"Failed to handle loop") + # clear so that new session can be triggered + self.connected = False + self.remote_stream_id = 0 + async def _on_audio(self, buff: bytearray): self.out_audio_buff += buff # Buffer audio @@ -371,7 +378,7 @@ def _fetch_properties(self, ten_env: TenEnv): except Exception as err: logger.info( f"GetProperty optional {PROPERTY_LANGUAGE} error: {err}") - + try: greeting = ten_env.get_property_string(PROPERTY_GREETING) if greeting: @@ -407,7 +414,7 @@ def _update_session(self) -> SessionUpdate: model="whisper-1"), tool_choice="auto", tools=self.registry.get_tools() - )) + )) ''' def _update_conversation(self) -> UpdateConversationConfig: @@ -481,7 +488,8 @@ def _register_local_tools(self) -> None: def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd): try: name = cmd.get_property_string(TOOL_REGISTER_PROPERTY_NAME) - description = cmd.get_property_string(TOOL_REGISTER_PROPERTY_DESCRIPTON) + description = cmd.get_property_string( + TOOL_REGISTER_PROPERTY_DESCRIPTON) pstr = cmd.get_property_string(TOOL_REGISTER_PROPERTY_PARAMETERS) parameters = json.loads(pstr) p = partial(self._remote_tool_call, ten_env) @@ -494,16 +502,16 @@ def _on_tool_register(self, ten_env: TenEnv, cmd: Cmd): except: logger.exception(f"Failed to register") - async def _remote_tool_call(self, ten_env: TenEnv, name:str, args: str, callback: Awaitable): + async def _remote_tool_call(self, ten_env: TenEnv, name: str, args: str, callback: Awaitable): logger.info(f"_remote_tool_call {name} {args}") c = Cmd.create(CMD_TOOL_CALL) c.set_property_string(CMD_PROPERTY_NAME, name) c.set_property_string(CMD_PROPERTY_ARGS, args) ten_env.send_cmd(c, lambda ten, result: asyncio.run_coroutine_threadsafe( - callback(result.get_property_string("response")), self.loop)) + callback(result.get_property_string("response")), self.loop)) logger.info(f"_remote_tool_call finish {name} {args}") - - async def _on_tool_output(self, tool_call_id:str, result: str): + + async def _on_tool_output(self, tool_call_id: str, result: str): logger.info(f"_on_tool_output {tool_call_id} {result}") try: tool_response = ItemCreate( @@ -526,4 +534,4 @@ def _greeting_text(self) -> str: text = "こんにちは" elif self.config.language == "ko-KR": text = "안녕하세요" - return text \ No newline at end of file + return text