diff --git a/agents/.gitignore b/agents/.gitignore index 175ad523..7e8da43e 100644 --- a/agents/.gitignore +++ b/agents/.gitignore @@ -2,6 +2,7 @@ ten_packages/extension_group/ ten_packages/extension/agora_rtc ten_packages/extension/azure_tts +ten_packages/extension/agora_sess_ctrl ten_packages/extension/py_init_extension_cpp ten_packages/system .ten diff --git a/agents/manifest-lock.json b/agents/manifest-lock.json index 1398fb18..01873ebb 100644 --- a/agents/manifest-lock.json +++ b/agents/manifest-lock.json @@ -67,6 +67,24 @@ } ] }, + { + "type": "extension", + "name": "agora_sess_ctrl", + "version": "0.2.0", + "hash": "52f26dee2fb8fbd22d55ecc6bb197176a51f4a3bd2268788f75582f68cf1270b", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime" + } + ], + "supports": [ + { + "os": "linux", + "arch": "x64" + } + ] + }, { "type": "system", "name": "azure_speech_sdk", diff --git a/agents/manifest.json b/agents/manifest.json index f92fe394..814cd3c6 100644 --- a/agents/manifest.json +++ b/agents/manifest.json @@ -18,6 +18,11 @@ "name": "agora_rtc", "version": "=0.8.0-rc2" }, + { + "type": "extension", + "name": "agora_sess_ctrl", + "version": "0.2.0" + }, { "type": "system", "name": "azure_speech_sdk", diff --git a/agents/property.json b/agents/property.json index 16c26867..c9790571 100644 --- a/agents/property.json +++ b/agents/property.json @@ -3474,7 +3474,160 @@ ] } ] + }, + { + "name": "va_minimax_v2v", + "auto_start": false, + "nodes": [ + { + "type": "extension", + "extension_group": "rtc", + "addon": "agora_rtc", + "name": "agora_rtc", + "property": { + "app_id": "${env:AGORA_APP_ID}", + "token": "", + "channel": "ten_agent_test", + "stream_id": 1234, + "remote_stream_id": 123, + "subscribe_audio": true, + "publish_audio": true, + "publish_data": true + } + }, + { + "type": "extension", + "extension_group": "agora_sess_ctrl", + "addon": "agora_sess_ctrl", + "name": "agora_sess_ctrl", + "property": { + "wait_for_eos": true + } + }, + { + "type": "extension", + "extension_group": "llm", + "addon": "minimax_v2v_python", + "name": "minimax_v2v_python", + "property": { + "in_sample_rate": 16000, + "token": "${env:MINIMAX_TOKEN}" + } + }, + { + "type": "extension", + "extension_group": "message_collector", + "addon": "message_collector", + "name": "message_collector" + } + ], + "connections": [ + { + "extension_group": "rtc", + "extension": "agora_rtc", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension_group": "agora_sess_ctrl", + "extension": "agora_sess_ctrl" + } + ] + } + ] + }, + { + "extension_group": "agora_sess_ctrl", + "extension": "agora_sess_ctrl", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension_group": "llm", + "extension": "minimax_v2v_python" + } + ] + } + ], + "cmd": [ + { + "name": "start_of_sentence", + "dest": [ + { + "extension_group": "llm", + "extension": "minimax_v2v_python", + "msg_conversion": { + "type": "per_property", + "keep_original": true, + "rules": [ + { + "path": "_ten.name", + "conversion_mode": "fixed_value", + "value": "flush" + } + ] + } + } + ] + } + ] + }, + { + "extension_group": "llm", + "extension": "minimax_v2v_python", + "data": [ + { + "name": "text_data", + "dest": [ + { + "extension_group": "message_collector", + "extension": "message_collector" + } + ] + } + ], + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension_group": "rtc", + "extension": "agora_rtc" + } + ] + } + ], + "cmd": [ + { + "name": "flush", + "dest": [ + { + "extension_group": "rtc", + "extension": 
"agora_rtc" + } + ] + } + ] + }, + { + "extension_group": "message_collector", + "extension": "message_collector", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "rtc", + "extension": "agora_rtc" + } + ] + } + ] + } + ] } ] } -} +} \ No newline at end of file diff --git a/agents/ten_packages/extension/message_collector/src/extension.py b/agents/ten_packages/extension/message_collector/src/extension.py index 7013c432..f54592b0 100644 --- a/agents/ten_packages/extension/message_collector/src/extension.py +++ b/agents/ten_packages/extension/message_collector/src/extension.py @@ -5,7 +5,9 @@ # Copyright (c) 2024 Agora IO. All rights reserved. # # +import base64 import json +import threading import time import uuid from ten import ( @@ -18,6 +20,7 @@ CmdResult, Data, ) +import asyncio from .log import logger MAX_SIZE = 800 # 1 KB limit @@ -32,9 +35,70 @@ # record the cached text data for each stream id cached_text_map = {} +MAX_CHUNK_SIZE_BYTES = 1024 + +def _text_to_base64_chunks(text: str, msg_id: str) -> list: + # Ensure msg_id does not exceed 50 characters + if len(msg_id) > 36: + raise ValueError("msg_id cannot exceed 36 characters.") + + # Convert text to bytearray + byte_array = bytearray(text, 'utf-8') + + # Encode the bytearray into base64 + base64_encoded = base64.b64encode(byte_array).decode('utf-8') + + # Initialize list to hold the final chunks + chunks = [] + + # We'll split the base64 string dynamically based on the final byte size + part_index = 0 + total_parts = None # We'll calculate total parts once we know how many chunks we create + + # Process the base64-encoded content in chunks + current_position = 0 + total_length = len(base64_encoded) + + while current_position < total_length: + part_index += 1 + + # Start guessing the chunk size by limiting the base64 content part + estimated_chunk_size = MAX_CHUNK_SIZE_BYTES # We'll reduce this dynamically + content_chunk = "" + count = 0 + while True: + # Create the content part of the chunk + content_chunk = base64_encoded[current_position:current_position + estimated_chunk_size] + + # Format the chunk + formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}" + + # Check if the byte length of the formatted chunk exceeds the max allowed size + if len(bytearray(formatted_chunk, 'utf-8')) <= MAX_CHUNK_SIZE_BYTES: + break + else: + # Reduce the estimated chunk size if the formatted chunk is too large + estimated_chunk_size -= 100 # Reduce content size gradually + count += 1 + + logger.debug(f"chunk estimate guess: {count}") + + # Add the current chunk to the list + chunks.append(formatted_chunk) + current_position += estimated_chunk_size # Move to the next part of the content + # Now that we know the total number of parts, update the chunks with correct total_parts + total_parts = len(chunks) + updated_chunks = [ + chunk.replace("???", str(total_parts)) for chunk in chunks + ] + + return updated_chunks class MessageCollectorExtension(Extension): + # Create the queue for message processing + queue = asyncio.Queue() + def on_init(self, ten_env: TenEnv) -> None: logger.info("MessageCollectorExtension on_init") ten_env.on_init_done() @@ -43,6 +107,13 @@ def on_start(self, ten_env: TenEnv) -> None: logger.info("MessageCollectorExtension on_start") # TODO: read properties, initialize resources + self.loop = asyncio.new_event_loop() + def start_loop(): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + threading.Thread(target=start_loop, args=[]).start() + + 
self.loop.create_task(self._process_queue(ten_env)) ten_env.on_start_done() @@ -74,7 +145,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: example: {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}} """ - logger.info(f"on_data") + logger.debug(f"on_data") text = "" final = True @@ -123,7 +194,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: cached_text_map[stream_id] = text # Generate a unique message ID for this batch of parts - message_id = str(uuid.uuid4()) + message_id = str(uuid.uuid4())[:8] # Prepare the main JSON structure without the text field base_msg_data = { @@ -132,61 +203,13 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: "message_id": message_id, # Add message_id to identify the split message "data_type": "transcribe", "text_ts": int(time.time() * 1000), # Convert to milliseconds + "text": text, } try: - # Convert the text to UTF-8 bytes - text_bytes = text.encode('utf-8') - - # If the text + metadata fits within the size limit, send it directly - if len(text_bytes) + OVERHEAD_ESTIMATE <= MAX_SIZE: - base_msg_data["text"] = text - msg_data = json.dumps(base_msg_data) - ten_data = Data.create("data") - ten_data.set_property_buf("data", msg_data.encode()) - ten_env.send_data(ten_data) - else: - # Split the text bytes into smaller chunks, ensuring safe UTF-8 splitting - max_text_size = MAX_SIZE - OVERHEAD_ESTIMATE - total_length = len(text_bytes) - total_parts = (total_length + max_text_size - 1) // max_text_size # Calculate number of parts - - def get_valid_utf8_chunk(start, end): - """Helper function to ensure valid UTF-8 chunks.""" - while end > start: - try: - # Decode to check if this chunk is valid UTF-8 - text_part = text_bytes[start:end].decode('utf-8') - return text_part, end - except UnicodeDecodeError: - # Reduce the end point to avoid splitting in the middle of a character - end -= 1 - # If no valid chunk is found (shouldn't happen with valid UTF-8 input), return an empty string - return "", start - - part_number = 0 - start_index = 0 - while start_index < total_length: - part_number += 1 - # Get a valid UTF-8 chunk - text_part, end_index = get_valid_utf8_chunk(start_index, min(start_index + max_text_size, total_length)) - - # Prepare the part data with metadata - part_data = base_msg_data.copy() - part_data.update({ - "text": text_part, - "part_number": part_number, - "total_parts": total_parts, - }) - - # Send each part - part_msg_data = json.dumps(part_data) - ten_data = Data.create("data") - ten_data.set_property_buf("data", part_msg_data.encode()) - ten_env.send_data(ten_data) - - # Move to the next chunk - start_index = end_index + chunks = _text_to_base64_chunks(json.dumps(base_msg_data), message_id) + for chunk in chunks: + asyncio.run_coroutine_threadsafe(self._queue_message(chunk), self.loop) except Exception as e: logger.warning(f"on_data new_data error: {e}") @@ -199,3 +222,19 @@ def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: # TODO: process image frame pass + + + async def _queue_message(self, data: str): + await self.queue.put(data) + + async def _process_queue(self, ten_env: TenEnv): + while True: + data = await self.queue.get() + if data is None: + break + # process data + ten_data = Data.create("data") + ten_data.set_property_buf("data", data.encode()) + ten_env.send_data(ten_data) + self.queue.task_done() + await asyncio.sleep(0.04) \ No newline at 
end of file diff --git a/agents/ten_packages/extension/minimax_v2v_python/README.md b/agents/ten_packages/extension/minimax_v2v_python/README.md new file mode 100644 index 00000000..c73a53c1 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/README.md @@ -0,0 +1,36 @@ +# MiniMax Voice-to-Voice Extension + +A TEN extension that implements voice-to-voice conversation capabilities using MiniMax's API services. + +## Features + +- Real-time voice-to-voice conversation +- Support for streaming responses, including the assistant's voice, the assistant's transcript, and the user's transcript +- Configurable voice settings +- Memory management for conversation context +- Asynchronous processing based on asyncio + + +## API + +Refer to the `api` definition in [manifest.json](manifest.json) and default values in [property.json](property.json). +`token` is mandatory to use MiniMax's API; the others are optional. + + + +## Development + +### Build + + + +### Unit test + + + +## Misc + + + +## References +- [ChatCompletion v2](https://platform.minimaxi.com/document/ChatCompletion%20v2?key=66701d281d57f38758d581d0#ww1u9KZvwrgnF2EfpPrnHHGd) diff --git a/agents/ten_packages/extension/minimax_v2v_python/__init__.py b/agents/ten_packages/extension/minimax_v2v_python/__init__.py new file mode 100644 index 00000000..72593ab2 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/__init__.py @@ -0,0 +1,6 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . import addon diff --git a/agents/ten_packages/extension/minimax_v2v_python/addon.py b/agents/ten_packages/extension/minimax_v2v_python/addon.py new file mode 100644 index 00000000..f6c0f882 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/addon.py @@ -0,0 +1,18 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from ten import ( +    Addon, +    register_addon_as_extension, +    TenEnv, +) +from .extension import MinimaxV2VExtension + + +@register_addon_as_extension("minimax_v2v_python") +class MinimaxV2VExtensionAddon(Addon): +    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: +        ten_env.log_info("on_create_instance") +        ten_env.on_create_instance_done(MinimaxV2VExtension(name), context) diff --git a/agents/ten_packages/extension/minimax_v2v_python/chat_memory.py b/agents/ten_packages/extension/minimax_v2v_python/chat_memory.py new file mode 100644 index 00000000..8ef98b10 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/chat_memory.py @@ -0,0 +1,42 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information.
+# +import threading + + +class ChatMemory: + def __init__(self, max_history_length): + self.max_history_length = max_history_length + self.history = [] + self.mutex = threading.Lock() # TODO: no need lock for asyncio + + def put(self, message): + with self.mutex: + self.history.append(message) + + while True: + history_count = len(self.history) + if history_count > 0 and history_count > self.max_history_length: + self.history.pop(0) + continue + if history_count > 0 and self.history[0]["role"] == "assistant": + # we cannot have an assistant message at the start of the chat history + # if after removal of the first, we have an assistant message, + # we need to remove the assistant message too + self.history.pop(0) + continue + break + + def get(self): + with self.mutex: + return self.history + + def count(self): + with self.mutex: + return len(self.history) + + def clear(self): + with self.mutex: + self.history = [] diff --git a/agents/ten_packages/extension/minimax_v2v_python/extension.py b/agents/ten_packages/extension/minimax_v2v_python/extension.py new file mode 100644 index 00000000..64143b69 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/extension.py @@ -0,0 +1,484 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from ten import ( + AudioFrame, + VideoFrame, + AudioFrameDataFmt, + AsyncExtension, + AsyncTenEnv, + Cmd, + StatusCode, + CmdResult, + Data, +) +from .util import duration_in_ms, duration_in_ms_since, Role +from .chat_memory import ChatMemory +from dataclasses import dataclass, fields +import builtins +import httpx +from datetime import datetime +import aiofiles +import asyncio +from typing import Iterator, List, Dict, Tuple, Any +import base64 +import json + + +@dataclass +class MinimaxV2VConfig: + token: str = "" + max_tokens: int = 1024 + model: str = "abab6.5s-chat" + voice_model: str = "speech-01-turbo-240228" + voice_id: str = "female-tianmei" + in_sample_rate: int = 16000 + out_sample_rate: int = 32000 + prompt: str = ( + "You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points." + ) + greeting: str = "" + max_memory_length: int = 10 + dump: bool = False + + def read_from_property(self, ten_env: AsyncTenEnv): + for field in fields(self): + # TODO: 'is_property_exist' has a bug that can not be used in async extension currently, use it instead of try .. 
except once fixed + # if not ten_env.is_property_exist(field.name): + # continue + try: + match field.type: + case builtins.str: + val = ten_env.get_property_string(field.name) + if val: + setattr(self, field.name, val) + ten_env.log_info(f"{field.name}={val}") + case builtins.int: + val = ten_env.get_property_int(field.name) + setattr(self, field.name, val) + ten_env.log_info(f"{field.name}={val}") + case builtins.bool: + val = ten_env.get_property_bool(field.name) + setattr(self, field.name, val) + ten_env.log_info(f"{field.name}={val}") + case _: + pass + except Exception as e: + ten_env.log_warn(f"get property for {field.name} failed, err {e}") + + +class MinimaxV2VExtension(AsyncExtension): + def __init__(self, name: str) -> None: + super().__init__(name) + + self.config = MinimaxV2VConfig() + self.client = httpx.AsyncClient(timeout=httpx.Timeout(5)) + self.memory = ChatMemory(self.config.max_memory_length) + self.remote_stream_id = 0 + self.ten_env = None + + # able to cancel + self.curr_task = None + + # make sure tasks processing in order + self.process_input_task = None + self.queue = asyncio.Queue() + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + self.config.read_from_property(ten_env=ten_env) + ten_env.log_info(f"config: {self.config}") + + self.memory = ChatMemory(self.config.max_memory_length) + self.ten_env = ten_env + ten_env.on_init_done() + + async def on_start(self, ten_env: AsyncTenEnv) -> None: + self.process_input_task = asyncio.create_task(self._process_input(ten_env=ten_env, queue=self.queue), name="process_input") + + ten_env.on_start_done() + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + + await self._flush(ten_env=ten_env) + self.queue.put_nowait(None) + if self.process_input_task: + self.process_input_task.cancel() + await asyncio.gather(self.process_input_task, return_exceptions=True) + self.process_input_task = None + + ten_env.on_stop_done() + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + ten_env.log_debug("on_deinit") + + if self.client: + await self.client.aclose() + self.client = None + self.ten_env = None + ten_env.on_deinit_done() + + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: + try: + cmd_name = cmd.get_name() + ten_env.log_debug("on_cmd name {}".format(cmd_name)) + + # process cmd + match cmd_name: + case "flush": + await self._flush(ten_env=ten_env) + _result = await ten_env.send_cmd(Cmd.create("flush")) + ten_env.log_debug("flush done") + case _: + pass + ten_env.return_result(CmdResult.create(StatusCode.OK), cmd) + except asyncio.CancelledError: + ten_env.log_warn(f"cmd {cmd_name} cancelled") + ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd) + raise + except Exception as e: + ten_env.log_warn(f"cmd {cmd_name} failed, err {e}") + finally: + pass + + async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None: + pass + + async def on_audio_frame( + self, ten_env: AsyncTenEnv, audio_frame: AudioFrame + ) -> None: + + try: + ts = datetime.now() + stream_id = audio_frame.get_property_int("stream_id") + if not self.remote_stream_id: + self.remote_stream_id = stream_id + + frame_buf = audio_frame.get_buf() + ten_env.log_debug(f"on audio frame {len(frame_buf)} {stream_id}") + + # process audio frame, must be after vad + # put_nowait to make sure put in_order + self.queue.put_nowait((ts, frame_buf)) + # await self._complete_with_history(ts, frame_buf) + + # dump input audio if need + await self._dump_audio_if_need(frame_buf, "in") + + # ten_env.log_debug(f"on audio frame 
{len(frame_buf)} {stream_id} put done") + except asyncio.CancelledError: + ten_env.log_warn(f"on audio frame cancelled") + raise + except Exception as e: + ten_env.log_error(f"on audio frame failed, err {e}") + + async def on_video_frame( + self, ten_env: AsyncTenEnv, video_frame: VideoFrame + ) -> None: + pass + + async def _process_input(self, ten_env: AsyncTenEnv, queue: asyncio.Queue): + ten_env.log_info("process_input started") + + while True: + item = await queue.get() + if not item: + break + + (ts, frame_buf) = item + ten_env.log_debug(f"start process task {ts} {len(frame_buf)}") + + try: + self.curr_task = asyncio.create_task(self._complete_with_history(ts, frame_buf)) + await self.curr_task + self.curr_task = None + except asyncio.CancelledError: + ten_env.log_warn("task cancelled") + except Exception as e: + ten_env.log_warn(f"task failed, err {e}") + finally: + queue.task_done() + + ten_env.log_info("process_input exit") + + async def _complete_with_history( + self, ts: datetime, buff: bytearray + ): + start_time = datetime.now() + ten_env = self.ten_env + ten_env.log_debug( + f"start request, buff len {len(buff)}, queued_time {duration_in_ms(ts, start_time)}ms" + ) + + # prepare messages with prompt and history + messages = [] + if self.config.prompt: + messages.append({"role": Role.System, "content": self.config.prompt}) + messages.extend(self.memory.get()) + ten_env.log_debug(f"messages without audio: [{messages}]") + messages.append( + self._create_input_audio_message(buff=buff) + ) # don't print audio message + + # prepare request + url = "https://api.minimax.chat/v1/text/chatcompletion_v2" + (headers, payload) = self._create_request(messages) + + # vars to calculate Time to first byte + user_transcript_ttfb = None + assistant_transcript_ttfb = None + assistant_audio_ttfb = None + + # vars for transcript + user_transcript = "" + assistant_transcript = "" + + try: + # send POST request + async with self.client.stream( + "POST", url, headers=headers, json=payload + ) as response: + trace_id = response.headers.get("Trace-Id", "") + alb_receive_time = response.headers.get("alb_receive_time", "") + ten_env.log_info( + f"Get response trace-id: {trace_id}, alb_receive_time: {alb_receive_time}, cost_time {duration_in_ms_since(start_time)}ms" + ) + + response.raise_for_status() # check response + + i = 0 + async for line in response.aiter_lines(): + # logger.info(f"-> line {line}") + # if self._need_interrupt(ts): + # ten_env.log_warn(f"trace-id: {trace_id}, interrupted") + # if self.transcript: + # self.transcript += "[interrupted]" + # self._append_message("assistant", self.transcript) + # self._send_transcript("", "assistant", True) + # break + + if not line.startswith("data:"): + ten_env.log_debug(f"ignore line {len(line)}") + continue + i += 1 + + resp = json.loads(line.strip("data:")) + if resp.get("choices") and resp["choices"][0].get("delta"): + delta = resp["choices"][0]["delta"] + if delta.get("role") == "assistant": + # text content + if delta.get("content"): + content = delta["content"] + assistant_transcript += content + if not assistant_transcript_ttfb: + assistant_transcript_ttfb = duration_in_ms_since( + start_time + ) + ten_env.log_info( + f"trace-id {trace_id} chunck-{i} get assistant_transcript_ttfb {assistant_transcript_ttfb}ms, assistant transcript [{content}]" + ) + else: + ten_env.log_info( + f"trace-id {trace_id} chunck-{i} get assistant transcript [{content}]" + ) + + # send out for transcript display + self._send_transcript( + ten_env=ten_env, + 
content=content, +                                    role=Role.Assistant, +                                    end_of_segment=False, +                                ) + +                            # audio content +                            if ( +                                delta.get("audio_content") +                                and delta["audio_content"] != "" +                            ): +                                ten_env.log_info( +                                    f"trace-id {trace_id} chunk-{i} get audio_content" +                                ) +                                if not assistant_audio_ttfb: +                                    assistant_audio_ttfb = duration_in_ms_since( +                                        start_time +                                    ) +                                    ten_env.log_info( +                                        f"trace-id {trace_id} chunk-{i} get assistant_audio_ttfb {assistant_audio_ttfb}ms" +                                    ) + +                                # send out +                                base64_str = delta["audio_content"] +                                buff = base64.b64decode(base64_str) +                                await self._dump_audio_if_need(buff, "out") +                                self._send_audio_frame(ten_env=ten_env, audio_data=buff) + +                            # tool calls +                            if delta.get("tool_calls"): +                                ten_env.log_warn(f"ignore tool call {delta}") +                                # TODO: add tool calls +                                continue + +                        if delta.get("role") == "user": +                            if delta.get("content"): +                                content = delta["content"] +                                user_transcript += content +                                if not user_transcript_ttfb: +                                    user_transcript_ttfb = duration_in_ms_since( +                                        start_time +                                    ) +                                    ten_env.log_info( +                                        f"trace-id: {trace_id} chunk-{i} get user_transcript_ttfb {user_transcript_ttfb}ms, user transcript [{content}]" +                                    ) +                                else: +                                    ten_env.log_info( +                                        f"trace-id {trace_id} chunk-{i} get user transcript [{content}]" +                                    ) + +                                # send out for transcript display +                                self._send_transcript( +                                    ten_env=ten_env, +                                    content=content, +                                    role=Role.User, +                                    end_of_segment=True, +                                ) + +        except httpx.TimeoutException: +            ten_env.log_warn("http timeout") +        except httpx.HTTPStatusError as e: +            ten_env.log_warn(f"http status error: {e}") +        except httpx.RequestError as e: +            ten_env.log_warn(f"http request error: {e}") +        finally: +            ten_env.log_info( +                f"http loop done, cost_time {duration_in_ms_since(start_time)}ms" +            ) +            if user_transcript: +                self.memory.put({"role": Role.User, "content": user_transcript}) +            if assistant_transcript: +                self.memory.put( +                    {"role": Role.Assistant, "content": assistant_transcript} +                ) +                self._send_transcript( +                    ten_env=ten_env, +                    content="", +                    role=Role.Assistant, +                    end_of_segment=True, +                ) + +    def _create_input_audio_message(self, buff: bytearray) -> Dict[str, Any]: +        message = { +            "role": "user", +            "content": [ +                { +                    "type": "input_audio", +                    "input_audio": { +                        "data": base64.b64encode(buff).decode("utf-8"), +                        "format": "pcm", +                        "sample_rate": self.config.in_sample_rate, +                        "bit_depth": 16, +                        "channel": 1, +                        "encode": "base64", +                    }, +                } +            ], +        } +        return message + +    def _create_request( +        self, messages: List[Any] +    ) -> Tuple[Dict[str, Any], Dict[str, Any]]: +        config = self.config + +        headers = { +            "Authorization": f"Bearer {config.token}", +            "Content-Type": "application/json", +        } + +        payload = { +            "model": config.model, +            "messages": messages, +            "tools": [], +            "tool_choice": "none", +            "stream": True, +            "stream_options": {"speech_output": True},  # enable speech output +            "voice_setting": { +                "model": config.voice_model, +                "voice_id": config.voice_id, +            }, +            "audio_setting": { +                "sample_rate": config.out_sample_rate, +                "format": "pcm", +                "channel": 1, +                "encode": "base64", +            }, +            "tools": [{"type": "web_search"}], +            "max_tokens": config.max_tokens, +            "temperature": 0.8, +            "top_p": 0.95, +        } + +        return (headers, payload) + +    def _send_audio_frame(self, ten_env: AsyncTenEnv, audio_data: bytearray) -> None: +        try: +            f = AudioFrame.create("pcm_frame") +            f.set_sample_rate(self.config.out_sample_rate) +            f.set_bytes_per_sample(2) +            f.set_number_of_channels(1) +            f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) +            f.set_samples_per_channel(len(audio_data) // 2) +            f.alloc_buf(len(audio_data)) +            buff = f.lock_buf() +            buff[:]
= audio_data +            f.unlock_buf(buff) +            ten_env.send_audio_frame(f) +        except Exception as e: +            ten_env.log_error(f"send audio frame failed, err {e}") + +    def _send_transcript( +        self, +        ten_env: AsyncTenEnv, +        content: str, +        role: str, +        end_of_segment: bool, +    ) -> None: +        stream_id = self.remote_stream_id if role == "user" else 0 + +        try: +            d = Data.create("text_data") +            d.set_property_string("text", content) +            d.set_property_bool("is_final", True) +            d.set_property_bool("end_of_segment", end_of_segment) +            d.set_property_string("role", role) +            d.set_property_int("stream_id", stream_id) +            ten_env.log_info( +                f"send transcript text [{content}] {stream_id} end_of_segment {end_of_segment} role {role}" +            ) +            self.ten_env.send_data(d) +        except Exception as e: +            ten_env.log_warn( +                f"send transcript text [{content}] {stream_id} end_of_segment {end_of_segment} role {role} failed, err {e}" +            ) + +    async def _flush(self, ten_env: AsyncTenEnv) -> None: +        # clear queue +        while not self.queue.empty(): +            try: +                self.queue.get_nowait() +                self.queue.task_done() +            except Exception as e: +                ten_env.log_warn(f"flush queue error {e}") + +        # cancel current task +        if self.curr_task: +            self.curr_task.cancel() +            await asyncio.gather(self.curr_task, return_exceptions=True) +            self.curr_task = None + +    async def _dump_audio_if_need(self, buf: bytearray, suffix: str) -> None: +        if not self.config.dump: +            return + +        async with aiofiles.open(f"minimax_v2v_{suffix}.pcm", "ab") as f: +            await f.write(buf) diff --git a/agents/ten_packages/extension/minimax_v2v_python/manifest.json b/agents/ten_packages/extension/minimax_v2v_python/manifest.json new file mode 100644 index 00000000..23b4432b --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/manifest.json @@ -0,0 +1,104 @@ +{ +  "type": "extension", +  "name": "minimax_v2v_python", +  "version": "0.1.0", +  "dependencies": [ +    { +      "type": "system", +      "name": "ten_runtime_python", +      "version": "0.3.1" +    } +  ], +  "package": { +    "include": [ +      "manifest.json", +      "property.json", +      "**.py", +      "README.md" +    ] +  }, +  "api": { +    "property": { +      "token": { +        "type": "string" +      }, +      "max_tokens": { +        "type": "int32" +      }, +      "model": { +        "type": "string" +      }, +      "voice_model": { +        "type": "string" +      }, +      "voice_id": { +        "type": "string" +      }, +      "in_sample_rate": { +        "type": "int32" +      }, +      "out_sample_rate": { +        "type": "int32" +      }, +      "prompt": { +        "type": "string" +      }, +      "greeting": { +        "type": "string" +      }, +      "max_memory_length": { +        "type": "int32" +      }, +      "dump": { +        "type": "bool" +      } +    }, +    "cmd_in": [ +      { +        "name": "flush" +      } +    ], +    "cmd_out": [ +      { +        "name": "flush" +      } +    ], +    "audio_frame_in": [ +      { +        "name": "pcm_frame", +        "property": { +          "stream_id": { +            "type": "uint32" +          } +        } +      } +    ], +    "audio_frame_out": [ +      { +        "name": "pcm_frame" +      } +    ], +    "data_out": [ +      { +        "name": "text_data", +        "property": { +          "text": { +            "type": "string" +          }, +          "is_final": { +            "type": "bool" +          }, +          "end_of_segment": { +            "type": "bool" +          }, +          "role": { +            "type": "string" +          }, +          "stream_id": { +            "type": "uint32" +          } +        } +      } +    ] +  } +} \ No newline at end of file diff --git a/agents/ten_packages/extension/minimax_v2v_python/property.json b/agents/ten_packages/extension/minimax_v2v_python/property.json new file mode 100644 index 00000000..9e26dfee --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/property.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/agents/ten_packages/extension/minimax_v2v_python/requirements.txt
b/agents/ten_packages/extension/minimax_v2v_python/requirements.txt new file mode 100644 index 00000000..73f98701 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/requirements.txt @@ -0,0 +1,2 @@ +aiofiles +httpx \ No newline at end of file diff --git a/agents/ten_packages/extension/minimax_v2v_python/util.py b/agents/ten_packages/extension/minimax_v2v_python/util.py new file mode 100644 index 00000000..f0411910 --- /dev/null +++ b/agents/ten_packages/extension/minimax_v2v_python/util.py @@ -0,0 +1,15 @@ +from datetime import datetime + + +def duration_in_ms(start: datetime, end: datetime) -> int: + return int((end - start).total_seconds() * 1000) + + +def duration_in_ms_since(start: datetime) -> int: + return duration_in_ms(start, datetime.now()) + + +class Role(str): + System = "system" + User = "user" + Assistant = "assistant" diff --git a/demo/src/manager/rtc/rtc.ts b/demo/src/manager/rtc/rtc.ts index 4139d89e..ca7e95de 100644 --- a/demo/src/manager/rtc/rtc.ts +++ b/demo/src/manager/rtc/rtc.ts @@ -12,6 +12,16 @@ import { AGEventEmitter } from "../events" import { RtcEvents, IUserTracks } from "./types" import { apiGenAgoraData } from "@/common" + +const TIMEOUT_MS = 5000; // Timeout for incomplete messages + +interface TextDataChunk { + message_id: string; + part_index: number; + total_parts: number; + content: string; +} + export class RtcManager extends AGEventEmitter { private _joined client: IAgoraRTCClient @@ -110,75 +120,94 @@ export class RtcManager extends AGEventEmitter { private _parseData(data: any): ITextItem | void { let decoder = new TextDecoder('utf-8'); let decodedMessage = decoder.decode(data); - const textstream = JSON.parse(decodedMessage); - console.log("[test] textstream raw data", JSON.stringify(textstream)); + console.log("[test] textstream raw data", decodedMessage); - const { stream_id, is_final, text, text_ts, data_type, message_id, part_number, total_parts } = textstream; + // const { stream_id, is_final, text, text_ts, data_type, message_id, part_number, total_parts } = textstream; - if (total_parts > 0) { - // If message is split, handle it accordingly - this._handleSplitMessage(message_id, part_number, total_parts, stream_id, is_final, text, text_ts); - } else { - // If there is no message_id, treat it as a complete message - this._handleCompleteMessage(stream_id, is_final, text, text_ts); - } + // if (total_parts > 0) { + // // If message is split, handle it accordingly + // this._handleSplitMessage(message_id, part_number, total_parts, stream_id, is_final, text, text_ts); + // } else { + // // If there is no message_id, treat it as a complete message + // this._handleCompleteMessage(stream_id, is_final, text, text_ts); + // } + + this.handleChunk(decodedMessage); } - private messageCache: { [key: string]: { parts: string[], totalParts: number } } = {}; - - /** - * Handle complete messages (not split). 
- */ - private _handleCompleteMessage(stream_id: number, is_final: boolean, text: string, text_ts: number): void { - const textItem: ITextItem = { - uid: `${stream_id}`, - time: text_ts, - dataType: "transcribe", - text: text, - isFinal: is_final - }; - - if (text.trim().length > 0) { - this.emit("textChanged", textItem); + + private messageCache: { [key: string]: TextDataChunk[] } = {}; + + // Function to process received chunk via event emitter + handleChunk(formattedChunk: string) { + try { + // Split the chunk by the delimiter "|" + const [message_id, partIndexStr, totalPartsStr, content] = formattedChunk.split('|'); + + const part_index = parseInt(partIndexStr, 10); + const total_parts = totalPartsStr === '???' ? -1 : parseInt(totalPartsStr, 10); // -1 means total parts unknown + + // Ensure total_parts is known before processing further + if (total_parts === -1) { + console.warn(`Total parts for message ${message_id} unknown, waiting for further parts.`); + return; + } + + const chunkData: TextDataChunk = { + message_id, + part_index, + total_parts, + content, + }; + + // Check if we already have an entry for this message + if (!this.messageCache[message_id]) { + this.messageCache[message_id] = []; + // Set a timeout to discard incomplete messages + setTimeout(() => { + if (this.messageCache[message_id]?.length !== total_parts) { + console.warn(`Incomplete message with ID ${message_id} discarded`); + delete this.messageCache[message_id]; // Discard incomplete message + } + }, TIMEOUT_MS); + } + + // Cache this chunk by message_id + this.messageCache[message_id].push(chunkData); + + // If all parts are received, reconstruct the message + if (this.messageCache[message_id].length === total_parts) { + const completeMessage = this.reconstructMessage(this.messageCache[message_id]); + const { stream_id, is_final, text, text_ts } = JSON.parse(atob(completeMessage)); + const textItem: ITextItem = { + uid: `${stream_id}`, + time: text_ts, + dataType: "transcribe", + text: text, + isFinal: is_final + }; + + if (text.trim().length > 0) { + this.emit("textChanged", textItem); + } + + + // Clean up the cache + delete this.messageCache[message_id]; + } + } catch (error) { + console.error('Error processing chunk:', error); } } - - /** - * Handle split messages, track parts, and reassemble once all parts are received. 
- */ - private _handleSplitMessage( - message_id: string, - part_number: number, - total_parts: number, - stream_id: number, - is_final: boolean, - text: string, - text_ts: number - ): void { - // Ensure the messageCache entry exists for this message_id - if (!this.messageCache[message_id]) { - this.messageCache[message_id] = { parts: [], totalParts: total_parts }; - } - - const cache = this.messageCache[message_id]; - - // Store the received part at the correct index (part_number starts from 1, so we use part_number - 1) - cache.parts[part_number - 1] = text; - - // Check if all parts have been received - const receivedPartsCount = cache.parts.filter(part => part !== undefined).length; - - if (receivedPartsCount === total_parts) { - // All parts have been received, reassemble the message - const fullText = cache.parts.join(''); - - // Now that the message is reassembled, handle it like a complete message - this._handleCompleteMessage(stream_id, is_final, fullText, text_ts); - - // Remove the cached message since it is now fully processed - delete this.messageCache[message_id]; - } + + // Function to reconstruct the full message from chunks + reconstructMessage(chunks: TextDataChunk[]): string { + // Sort chunks by their part index + chunks.sort((a, b) => a.part_index - b.part_index); + + // Concatenate all chunks to form the full message + return chunks.map(chunk => chunk.content).join(''); } @@ -196,4 +225,4 @@ export class RtcManager extends AGEventEmitter { } -export const rtcManager = new RtcManager() +export const rtcManager = new RtcManager() \ No newline at end of file