diff --git a/agents/examples/experimental/property.json b/agents/examples/experimental/property.json index a6ceb0e8..b532fea5 100644 --- a/agents/examples/experimental/property.json +++ b/agents/examples/experimental/property.json @@ -369,7 +369,7 @@ { "type": "extension", "extension_group": "tts", - "addon": "elevenlabs_tts", + "addon": "elevenlabs_tts_python", "name": "elevenlabs_tts", "property": { "api_key": "${env:ELEVENLABS_TTS_KEY}", @@ -379,7 +379,6 @@ "similarity_boost": 0.75, "speaker_boost": false, "stability": 0.5, - "style": 0.0, "voice_id": "pNInz6obpgDQGcFmaJgB" } }, @@ -878,7 +877,7 @@ { "type": "extension", "extension_group": "tts", - "addon": "cosy_tts", + "addon": "cosy_tts_python", "name": "cosy_tts", "property": { "api_key": "${env:QWEN_API_KEY}", @@ -1080,7 +1079,7 @@ { "type": "extension", "extension_group": "tts", - "addon": "cosy_tts", + "addon": "cosy_tts_python", "name": "cosy_tts", "property": { "api_key": "${env:QWEN_API_KEY}", @@ -2077,7 +2076,7 @@ { "type": "extension", "extension_group": "tts", - "addon": "cosy_tts", + "addon": "cosy_tts_python", "name": "cosy_tts", "property": { "api_key": "${env:QWEN_API_KEY}", @@ -3651,7 +3650,7 @@ { "type": "extension", "extension_group": "chatgpt", - "addon": "openai_chatgpt", + "addon": "openai_chatgpt_python", "name": "openai_chatgpt", "property": { "base_url": "${env:OPENAI_API_BASE}", diff --git a/agents/ten_packages/extension/cosy_tts/__init__.py b/agents/ten_packages/extension/cosy_tts/__init__.py deleted file mode 100644 index d7a1c8ec..00000000 --- a/agents/ten_packages/extension/cosy_tts/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import cosy_tts_addon - -print("cosy_tts extension loaded") diff --git a/agents/ten_packages/extension/cosy_tts/cosy_tts_addon.py b/agents/ten_packages/extension/cosy_tts/cosy_tts_addon.py deleted file mode 100644 index 2b36e02d..00000000 --- a/agents/ten_packages/extension/cosy_tts/cosy_tts_addon.py +++ /dev/null @@ -1,16 +0,0 @@ -from ten import ( - Addon, - register_addon_as_extension, - TenEnv, -) - - -@register_addon_as_extension("cosy_tts") -class CosyTTSExtensionAddon(Addon): - def on_create_instance(self, ten: TenEnv, addon_name: str, context) -> None: - from .log import logger - logger.info("on_create_instance") - - from .cosy_tts_extension import CosyTTSExtension - - ten.on_create_instance_done(CosyTTSExtension(addon_name), context) diff --git a/agents/ten_packages/extension/cosy_tts/cosy_tts_extension.py b/agents/ten_packages/extension/cosy_tts/cosy_tts_extension.py deleted file mode 100644 index 647341ad..00000000 --- a/agents/ten_packages/extension/cosy_tts/cosy_tts_extension.py +++ /dev/null @@ -1,251 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by Wei Hu in 2024-05. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import traceback -from ten import ( - Extension, - TenEnv, - Cmd, - AudioFrame, - AudioFrameDataFmt, - Data, - StatusCode, - CmdResult, -) -from typing import List, Any -import dashscope -import queue -import threading -from datetime import datetime -from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat -from .log import logger - - -class CosyTTSCallback(ResultCallback): - def __init__(self, ten: TenEnv, sample_rate: int, need_interrupt_callback): - super().__init__() - self.ten = ten - self.sample_rate = sample_rate - self.frame_size = int(self.sample_rate * 1 * 2 / 100) - self.ts = datetime.now() # current task ts - self.init_ts = datetime.now() - self.ttfb = None # time to first byte - self.need_interrupt_callback = need_interrupt_callback - self.closed = False - - def need_interrupt(self) -> bool: - return self.need_interrupt_callback(self.ts) - - def set_input_ts(self, ts: datetime): - self.ts = ts - - def on_open(self): - logger.info("websocket is open.") - - def on_complete(self): - logger.info("speech synthesis task complete successfully.") - - def on_error(self, message: str): - logger.info(f"speech synthesis task failed, {message}") - - def on_close(self): - logger.info("websocket is closed.") - self.closed = True - - def on_event(self, message): - pass - # logger.info(f"recv speech synthsis message {message}") - - def get_frame(self, data: bytes) -> AudioFrame: - f = AudioFrame.create("pcm_frame") - f.set_sample_rate(self.sample_rate) - f.set_bytes_per_sample(2) - f.set_number_of_channels(1) - # f.set_timestamp = 0 - f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) - f.set_samples_per_channel(len(data) // 2) - f.alloc_buf(len(data)) - buff = f.lock_buf() - buff[:] = data - f.unlock_buf(buff) - return f - - def on_data(self, data: bytes) -> None: - if self.need_interrupt(): - return - if self.ttfb is None: - self.ttfb = datetime.now() - self.init_ts - logger.info("TTS TTFB {}ms".format(int(self.ttfb.total_seconds() * 1000))) - - # logger.info("audio result length: %d, %d", len(data), self.frame_size) - try: - f = self.get_frame(data) - self.ten.send_audio_frame(f) - except Exception as e: - logger.exception(e) - - -class CosyTTSExtension(Extension): - def __init__(self, name: str): - super().__init__(name) - self.api_key = "" - self.voice = "" - self.model = "" - self.sample_rate = 16000 - self.tts = None - self.callback = None - self.format = None - - self.outdate_ts = datetime.now() - - self.stopped = False - self.thread = None - self.queue = queue.Queue() - - def on_start(self, ten: TenEnv) -> None: - logger.info("on_start") - self.api_key = ten.get_property_string("api_key") - self.voice = ten.get_property_string("voice") - self.model = ten.get_property_string("model") - self.sample_rate = ten.get_property_int("sample_rate") - - dashscope.api_key = self.api_key - f = AudioFormat.PCM_16000HZ_MONO_16BIT - if self.sample_rate == 8000: - f = AudioFormat.PCM_8000HZ_MONO_16BIT - elif self.sample_rate == 16000: - f = AudioFormat.PCM_16000HZ_MONO_16BIT - elif self.sample_rate == 22050: - f = AudioFormat.PCM_22050HZ_MONO_16BIT - elif self.sample_rate == 24000: - f = AudioFormat.PCM_24000HZ_MONO_16BIT - elif self.sample_rate == 44100: - f = AudioFormat.PCM_44100HZ_MONO_16BIT - elif self.sample_rate == 48000: - f = AudioFormat.PCM_48000HZ_MONO_16BIT - else: - logger.error("unknown sample rate %d", self.sample_rate) - exit() - - self.format = f - - self.thread = threading.Thread(target=self.async_handle, args=[ten]) - self.thread.start() - ten.on_start_done() - - def on_stop(self, ten: TenEnv) -> None: - logger.info("on_stop") - - self.stopped = True - self.flush() - self.queue.put(None) - if self.thread is not None: - self.thread.join() - self.thread = None - ten.on_stop_done() - - def need_interrupt(self, ts: datetime.time) -> bool: - return self.outdate_ts > ts - - def async_handle(self, ten: TenEnv): - try: - tts = None - callback = None - while not self.stopped: - try: - value = self.queue.get() - if value is None: - break - input_text, ts, end_of_segment = value - - # clear tts if old one is closed already - if callback is not None and callback.closed is True: - tts = None - callback = None - - # cancel last streaming call to avoid unprocessed audio coming back - if ( - callback is not None - and tts is not None - and callback.need_interrupt() - ): - tts.streaming_cancel() - tts = None - callback = None - - if self.need_interrupt(ts): - logger.info("drop outdated input") - continue - - # create new tts if needed - if tts is None or callback is None: - logger.info("creating tts") - callback = CosyTTSCallback( - ten, self.sample_rate, self.need_interrupt - ) - tts = SpeechSynthesizer( - model=self.model, - voice=self.voice, - format=self.format, - callback=callback, - ) - - logger.info( - "on message [{}] ts [{}] end_of_segment [{}]".format( - input_text, ts, end_of_segment - ) - ) - - # make sure new data won't be marked as outdated - callback.set_input_ts(ts) - - if len(input_text) > 0: - # last segment may have empty text but is_end is true - tts.streaming_call(input_text) - - # complete the streaming call to drain remained audio - if True: # end_of_segment: - try: - tts.streaming_complete() - except Exception as e: - logger.warning(e) - tts = None - callback = None - except Exception as e: - logger.exception(e) - logger.exception(traceback.format_exc()) - finally: - if tts is not None: - tts.streaming_cancel() - tts = None - callback = None - - def flush(self): - while not self.queue.empty(): - self.queue.get() - - def on_data(self, ten: TenEnv, data: Data) -> None: - inputText = data.get_property_string("text") - end_of_segment = data.get_property_bool("end_of_segment") - - logger.info("on data {} {}".format(inputText, end_of_segment)) - self.queue.put((inputText, datetime.now(), end_of_segment)) - - def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: - cmd_name = cmd.get_name() - logger.info("on_cmd {}".format(cmd_name)) - if cmd_name == "flush": - self.outdate_ts = datetime.now() - self.flush() - cmd_out = Cmd.create("flush") - ten.send_cmd(cmd_out, lambda ten, result: print("send_cmd flush done")) - else: - logger.info("unknown cmd {}".format(cmd_name)) - - cmd_result = CmdResult.create(StatusCode.OK) - cmd_result.set_property_string("detail", "success") - ten.return_result(cmd_result, cmd) diff --git a/agents/ten_packages/extension/cosy_tts/log.py b/agents/ten_packages/extension/cosy_tts/log.py deleted file mode 100644 index 83e89596..00000000 --- a/agents/ten_packages/extension/cosy_tts/log.py +++ /dev/null @@ -1,13 +0,0 @@ -import logging - -logger = logging.getLogger("COSY_TTS") -logger.setLevel(logging.INFO) - -formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s" -) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/cosy_tts/manifest.json b/agents/ten_packages/extension/cosy_tts/manifest.json deleted file mode 100644 index e330429f..00000000 --- a/agents/ten_packages/extension/cosy_tts/manifest.json +++ /dev/null @@ -1,53 +0,0 @@ -{ - "type": "extension", - "name": "cosy_tts", - "version": "0.4.0", - "dependencies": [ - { - "type": "system", - "name": "ten_runtime_python", - "version": "0.4" - } - ], - "api": { - "property": { - "api_key": { - "type": "string" - }, - "voice": { - "type": "string" - }, - "model": { - "type": "string" - }, - "sample_rate": { - "type": "int64" - } - }, - "data_in": [ - { - "name": "text_data", - "property": { - "text": { - "type": "string" - } - } - } - ], - "cmd_in": [ - { - "name": "flush" - } - ], - "cmd_out": [ - { - "name": "flush" - } - ], - "audio_frame_out": [ - { - "name": "pcm_frame" - } - ] - } -} \ No newline at end of file diff --git a/agents/ten_packages/extension/cosy_tts/requirements.txt b/agents/ten_packages/extension/cosy_tts/requirements.txt deleted file mode 100644 index f1c09c9e..00000000 --- a/agents/ten_packages/extension/cosy_tts/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -dashscope==1.20.0 \ No newline at end of file diff --git a/agents/ten_packages/extension/cosy_tts_python/BUILD.gn b/agents/ten_packages/extension/cosy_tts_python/BUILD.gn new file mode 100644 index 00000000..40bb6dd4 --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/BUILD.gn @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import("//build/feature/ten_package.gni") + +ten_package("cosy_tts_python") { + package_kind = "extension" + + resources = [ + "__init__.py", + "addon.py", + "extension.py", + "manifest.json", + "property.json", + "tests", + ] +} diff --git a/agents/ten_packages/extension/cosy_tts_python/README.md b/agents/ten_packages/extension/cosy_tts_python/README.md new file mode 100644 index 00000000..2f0cd08f --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/README.md @@ -0,0 +1,29 @@ +# cosy_tts_python + + + +## Features + + + +- xxx feature + +## API + +Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). + + + +## Development + +### Build + + + +### Unit test + + + +## Misc + + diff --git a/agents/ten_packages/extension/cosy_tts_python/__init__.py b/agents/ten_packages/extension/cosy_tts_python/__init__.py new file mode 100644 index 00000000..72593ab2 --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/__init__.py @@ -0,0 +1,6 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . import addon diff --git a/agents/ten_packages/extension/cosy_tts_python/addon.py b/agents/ten_packages/extension/cosy_tts_python/addon.py new file mode 100644 index 00000000..ad4d8df0 --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/addon.py @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from ten import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("cosy_tts_python") +class CosyTTSExtensionAddon(Addon): + + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import CosyTTSExtension + ten_env.log_info("CosyTTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done(CosyTTSExtension(name), context) diff --git a/agents/ten_packages/extension/cosy_tts_python/cosy_tts.py b/agents/ten_packages/extension/cosy_tts_python/cosy_tts.py new file mode 100644 index 00000000..65e3e866 --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/cosy_tts.py @@ -0,0 +1,108 @@ +import asyncio +from dataclasses import dataclass +from typing import AsyncIterator + +from websocket import WebSocketConnectionClosedException + +from ten.async_ten_env import AsyncTenEnv +from ten_ai_base.config import BaseConfig +import dashscope +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse +from dashscope.audio.tts_v2 import * + + +@dataclass +class CosyTTSConfig(BaseConfig): + api_key: str = "" + voice: str = "longxiaochun" + model: str = "cosyvoice-v1" + sample_rate: int = 16000 + + +class AsyncIteratorCallback(ResultCallback): + def __init__(self, ten_env: AsyncTenEnv, queue: asyncio.Queue) -> None: + self.closed = False + self.ten_env = ten_env + self.loop = asyncio.get_event_loop() + self.queue = queue + + def close(self): + self.closed = True + + def on_open(self): + self.ten_env.log_info("websocket is open.") + + def on_complete(self): + self.ten_env.log_info("speech synthesis task complete successfully.") + + def on_error(self, message: str): + self.ten_env.log_error(f"speech synthesis task failed, {message}") + + def on_close(self): + self.ten_env.log_info("websocket is closed.") + self.close() + + def on_event(self, message: str) -> None: + self.ten_env.log_debug(f"received event: {message}") + + def on_data(self, data: bytes) -> None: + if self.closed: + self.ten_env.log_warn(f"received data: {len(data)} bytes but connection was closed") + return + self.ten_env.log_debug(f"received data: {len(data)} bytes") + asyncio.run_coroutine_threadsafe(self.queue.put(data), self.loop) + + +class CosyTTS: + def __init__(self, config: CosyTTSConfig) -> None: + self.config = config + self.synthesizer = None # Initially no synthesizer + self.queue = asyncio.Queue() + dashscope.api_key = config.api_key + + def _create_synthesizer(self, ten_env: AsyncTenEnv, callback: AsyncIteratorCallback): + if self.synthesizer: + self.synthesizer = None + + ten_env.log_info("Creating new synthesizer") + self.synthesizer = SpeechSynthesizer( + model=self.config.model, + voice=self.config.voice, + format=AudioFormat.PCM_16000HZ_MONO_16BIT, + callback=callback, + ) + + async def get_audio_bytes(self) -> bytes: + return await self.queue.get() + + def text_to_speech_stream( + self, ten_env: AsyncTenEnv, text: str, end_of_segment: bool + ) -> None: + try: + callback = AsyncIteratorCallback(ten_env, self.queue) + + if not self.synthesizer or end_of_segment: + self._create_synthesizer(ten_env, callback) + + self.synthesizer.streaming_call(text) + + if end_of_segment: + ten_env.log_info("Streaming complete") + self.synthesizer.streaming_complete() + self.synthesizer = None + except WebSocketConnectionClosedException as e: + ten_env.log_error(f"WebSocket connection closed, {e}") + self.synthesizer = None + except Exception as e: + ten_env.log_error(f"Error streaming text, {e}") + self.synthesizer = None + + def cancel(self, ten_env: AsyncTenEnv) -> None: + if self.synthesizer: + try: + self.synthesizer.streaming_cancel() + except WebSocketConnectionClosedException as e: + ten_env.log_error(f"WebSocket connection closed, {e}") + except Exception as e: + ten_env.log_error(f"Error cancelling streaming, {e}") + self.synthesizer = None diff --git a/agents/ten_packages/extension/cosy_tts_python/extension.py b/agents/ten_packages/extension/cosy_tts_python/extension.py new file mode 100644 index 00000000..63041186 --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/extension.py @@ -0,0 +1,59 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from asyncio import sleep +import asyncio +from .cosy_tts import CosyTTS, CosyTTSConfig +from ten import ( + AsyncTenEnv, +) +from ten_ai_base.tts import AsyncTTSBaseExtension + + +class CosyTTSExtension(AsyncTTSBaseExtension): + def __init__(self, name: str) -> None: + super().__init__(name) + self.client = None + self.config = None + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + ten_env.log_debug("on_init") + + async def on_start(self, ten_env: AsyncTenEnv) -> None: + await super().on_start(ten_env) + ten_env.log_debug("on_start") + + self.config = CosyTTSConfig.create(ten_env=ten_env) + self.client = CosyTTS(self.config) + + asyncio.create_task(self._process_audio_data(ten_env)) + + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + ten_env.log_debug("on_stop") + + await self.queue.put(None) + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + ten_env.log_debug("on_deinit") + + async def _process_audio_data(self, ten_env: AsyncTenEnv) -> None: + while True: + audio_data = await self.client.get_audio_bytes() + + if audio_data is None: + break + + self.send_audio_out(ten_env, audio_data) + + + async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None: + self.client.text_to_speech_stream(ten_env, input_text, end_of_segment) + + async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None: + self.client.cancel(ten_env) diff --git a/agents/ten_packages/extension/cosy_tts_python/manifest.json b/agents/ten_packages/extension/cosy_tts_python/manifest.json new file mode 100644 index 00000000..d32d2d2c --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/manifest.json @@ -0,0 +1,64 @@ +{ + "type": "extension", + "name": "cosy_tts_python", + "version": "0.4.2", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.4.2" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md", + "tests/**" + ] + }, + "api": { + "property": { + "api_key": { + "type": "string" + }, + "voice": { + "type": "string" + }, + "model": { + "type": "string" + }, + "sample_rate": { + "type": "int64" + } + }, + "data_in": [ + { + "name": "text_data", + "property": { + "text": { + "type": "string" + } + } + } + ], + "cmd_in": [ + { + "name": "flush" + } + ], + "cmd_out": [ + { + "name": "flush" + } + ], + "audio_frame_out": [ + { + "name": "pcm_frame" + } + ] + } +} \ No newline at end of file diff --git a/agents/ten_packages/extension/cosy_tts/property.json b/agents/ten_packages/extension/cosy_tts_python/property.json similarity index 100% rename from agents/ten_packages/extension/cosy_tts/property.json rename to agents/ten_packages/extension/cosy_tts_python/property.json diff --git a/agents/ten_packages/extension/cosy_tts_python/requirements.txt b/agents/ten_packages/extension/cosy_tts_python/requirements.txt new file mode 100644 index 00000000..5899464f --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/requirements.txt @@ -0,0 +1 @@ +dashscope \ No newline at end of file diff --git a/agents/ten_packages/extension/cosy_tts_python/tests/test_basic.py b/agents/ten_packages/extension/cosy_tts_python/tests/test_basic.py new file mode 100644 index 00000000..c3755f44 --- /dev/null +++ b/agents/ten_packages/extension/cosy_tts_python/tests/test_basic.py @@ -0,0 +1,36 @@ +# +# Copyright © 2024 Agora +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0, with certain conditions. +# Refer to the "LICENSE" file in the root directory for more information. +# +from pathlib import Path +from ten import ExtensionTester, TenEnvTester, Cmd, CmdResult, StatusCode + + +class ExtensionTesterBasic(ExtensionTester): + def check_hello(self, ten_env: TenEnvTester, result: CmdResult): + statusCode = result.get_status_code() + print("receive hello_world, status:" + str(statusCode)) + + if statusCode == StatusCode.OK: + ten_env.stop_test() + + def on_start(self, ten_env: TenEnvTester) -> None: + new_cmd = Cmd.create("hello_world") + + print("send hello_world") + ten_env.send_cmd( + new_cmd, + lambda ten_env, result: self.check_hello(ten_env, result), + ) + + print("tester on_start_done") + ten_env.on_start_done() + + +def test_basic(): + tester = ExtensionTesterBasic() + tester.add_addon_base_dir(str(Path(__file__).resolve().parent.parent)) + tester.set_test_mode_single("default_async_extension_python") + tester.run() diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/BUILD.gn b/agents/ten_packages/extension/elevenlabs_tts_python/BUILD.gn new file mode 100644 index 00000000..008eed3d --- /dev/null +++ b/agents/ten_packages/extension/elevenlabs_tts_python/BUILD.gn @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import("//build/feature/ten_package.gni") + +ten_package("elevenlabs_tts_python") { + package_kind = "extension" + + resources = [ + "__init__.py", + "addon.py", + "extension.py", + "manifest.json", + "property.json", + "tests", + ] +} diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/README.md b/agents/ten_packages/extension/elevenlabs_tts_python/README.md new file mode 100644 index 00000000..e6032c06 --- /dev/null +++ b/agents/ten_packages/extension/elevenlabs_tts_python/README.md @@ -0,0 +1,29 @@ +# elevenlabs_tts_python + + + +## Features + + + +- xxx feature + +## API + +Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). + + + +## Development + +### Build + + + +### Unit test + + + +## Misc + + diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/__init__.py b/agents/ten_packages/extension/elevenlabs_tts_python/__init__.py index d7fe7e64..72593ab2 100644 --- a/agents/ten_packages/extension/elevenlabs_tts_python/__init__.py +++ b/agents/ten_packages/extension/elevenlabs_tts_python/__init__.py @@ -1,6 +1,6 @@ -from . import elevenlabs_tts_addon -from .extension import EXTENSION_NAME -from .log import logger - - -logger.info(f"{EXTENSION_NAME} extension loaded") +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . import addon diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/addon.py b/agents/ten_packages/extension/elevenlabs_tts_python/addon.py new file mode 100644 index 00000000..af96068d --- /dev/null +++ b/agents/ten_packages/extension/elevenlabs_tts_python/addon.py @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from ten import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("elevenlabs_tts_python") +class ElevenLabsTTSExtensionAddon(Addon): + + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import ElevenLabsTTSExtension + ten_env.log_info("ElevenLabsTTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done(ElevenLabsTTSExtension(name), context) diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts.py b/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts.py index 282024fc..e87761fd 100644 --- a/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts.py +++ b/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts.py @@ -6,46 +6,33 @@ # # -from typing import Iterator +from dataclasses import dataclass +from typing import AsyncIterator, Iterator from elevenlabs import Voice, VoiceSettings -from elevenlabs.client import ElevenLabs +from elevenlabs.client import AsyncElevenLabs +from ten_ai_base.config import BaseConfig -class ElevenlabsTTSConfig: - def __init__( - self, - api_key="", - model_id="eleven_multilingual_v2", - optimize_streaming_latency=0, - request_timeout_seconds=30, - similarity_boost=0.75, - speaker_boost=False, - stability=0.5, - style=0.0, - voice_id="pNInz6obpgDQGcFmaJgB", - ) -> None: - self.api_key = api_key - self.model_id = model_id - self.optimize_streaming_latency = optimize_streaming_latency - self.request_timeout_seconds = request_timeout_seconds - self.similarity_boost = similarity_boost - self.speaker_boost = speaker_boost - self.stability = stability - self.style = style - self.voice_id = voice_id +@dataclass +class ElevenLabsTTSConfig(BaseConfig): + api_key: str = "" + model_id: str = "eleven_multilingual_v2" + optimize_streaming_latency: int = 0 + similarity_boost: float = 0.75 + speaker_boost: bool = False + stability: float = 0.5 + request_timeout_seconds: int = 10 + style: float = 0.0 + voice_id: str = "pNInz6obpgDQGcFmaJgB" -def default_elevenlabs_tts_config() -> ElevenlabsTTSConfig: - return ElevenlabsTTSConfig() - - -class ElevenlabsTTS: - def __init__(self, config: ElevenlabsTTSConfig) -> None: +class ElevenLabsTTS: + def __init__(self, config: ElevenLabsTTSConfig) -> None: self.config = config - self.client = ElevenLabs(api_key=config.api_key, timeout=config.request_timeout_seconds) + self.client = AsyncElevenLabs(api_key=config.api_key, timeout=config.request_timeout_seconds) - def text_to_speech_stream(self, text: str) -> Iterator[bytes]: - audio_stream = self.client.generate( + def text_to_speech_stream(self, text: str) -> AsyncIterator[bytes]: + return self.client.generate( text=text, model=self.config.model_id, optimize_streaming_latency=self.config.optimize_streaming_latency, @@ -60,6 +47,4 @@ def text_to_speech_stream(self, text: str) -> Iterator[bytes]: speaker_boost=self.config.speaker_boost, ), ), - ) - - return audio_stream + ) \ No newline at end of file diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py b/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py deleted file mode 100644 index a6361f63..00000000 --- a/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py +++ /dev/null @@ -1,24 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by XinHui Li in 2024-07. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# - -from ten import ( - Addon, - register_addon_as_extension, - TenEnv, -) -from .extension import EXTENSION_NAME - - -@register_addon_as_extension(EXTENSION_NAME) -class ElevenlabsTTSExtensionAddon(Addon): - def on_create_instance(self, ten: TenEnv, addon_name: str, context) -> None: - from .log import logger - logger.info("on_create_instance") - from .elevenlabs_tts_extension import ElevenlabsTTSExtension - - ten.on_create_instance_done(ElevenlabsTTSExtension(addon_name), context) diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py b/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py deleted file mode 100644 index 1e47a359..00000000 --- a/agents/ten_packages/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py +++ /dev/null @@ -1,228 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by XinHui Li in 2024-07. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# - -import queue -import threading -import time - -from ten import ( - Extension, - TenEnv, - Cmd, - CmdResult, - StatusCode, - Data, -) -from .elevenlabs_tts import default_elevenlabs_tts_config, ElevenlabsTTS -from .pcm import PcmConfig, Pcm -from .log import logger - -CMD_IN_FLUSH = "flush" -CMD_OUT_FLUSH = "flush" - -DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text" - -PROPERTY_API_KEY = "api_key" # Required -PROPERTY_MODEL_ID = "model_id" # Optional -PROPERTY_OPTIMIZE_STREAMING_LATENCY = "optimize_streaming_latency" # Optional -PROPERTY_REQUEST_TIMEOUT_SECONDS = "request_timeout_seconds" # Optional -PROPERTY_SIMILARITY_BOOST = "similarity_boost" # Optional -PROPERTY_SPEAKER_BOOST = "speaker_boost" # Optional -PROPERTY_STABILITY = "stability" # Optional -PROPERTY_STYLE = "style" # Optional - - -class Message: - def __init__(self, text: str, received_ts: int) -> None: - self.text = text - self.received_ts = received_ts - - -class ElevenlabsTTSExtension(Extension): - def on_start(self, ten: TenEnv) -> None: - logger.info("on_start") - - self.elevenlabs_tts = None - self.outdate_ts = 0 - self.pcm = None - self.pcm_frame_size = 0 - self.text_queue = queue.Queue(maxsize=1024) - - # prepare configuration - elevenlabs_tts_config = default_elevenlabs_tts_config() - - try: - elevenlabs_tts_config.api_key = ten.get_property_string(PROPERTY_API_KEY) - except Exception as e: - logger.warning(f"on_start get_property_string {PROPERTY_API_KEY} error: {e}") - return - - try: - model_id = ten.get_property_string(PROPERTY_MODEL_ID) - if len(model_id) > 0: - elevenlabs_tts_config.model_id = model_id - except Exception as e: - logger.warning(f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}") - - try: - optimize_streaming_latency = ten.get_property_int(PROPERTY_OPTIMIZE_STREAMING_LATENCY) - if optimize_streaming_latency > 0: - elevenlabs_tts_config.optimize_streaming_latency = optimize_streaming_latency - except Exception as e: - logger.warning(f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}") - - try: - request_timeout_seconds = ten.get_property_int(PROPERTY_REQUEST_TIMEOUT_SECONDS) - if request_timeout_seconds > 0: - elevenlabs_tts_config.request_timeout_seconds = request_timeout_seconds - except Exception as e: - logger.warning(f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}") - - try: - elevenlabs_tts_config.similarity_boost = ten.get_property_float(PROPERTY_SIMILARITY_BOOST) - except Exception as e: - logger.warning(f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}") - - try: - elevenlabs_tts_config.speaker_boost = ten.get_property_bool(PROPERTY_SPEAKER_BOOST) - except Exception as e: - logger.warning(f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}") - - try: - elevenlabs_tts_config.stability = ten.get_property_float(PROPERTY_STABILITY) - except Exception as e: - logger.warning(f"on_start get_property_float {PROPERTY_STABILITY} error: {e}") - - try: - elevenlabs_tts_config.style = ten.get_property_float(PROPERTY_STYLE) - except Exception as e: - logger.warning(f"on_start get_property_float {PROPERTY_STYLE} error: {e}") - - # create elevenlabsTTS instance - self.elevenlabs_tts = ElevenlabsTTS(elevenlabs_tts_config) - - logger.info(f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}") - - # create pcm instance - self.pcm = Pcm(PcmConfig()) - self.pcm_frame_size = self.pcm.get_pcm_frame_size() - - threading.Thread(target=self.process_text_queue, args=(ten,)).start() - - ten.on_start_done() - - def on_stop(self, ten: TenEnv) -> None: - logger.info("on_stop") - ten.on_stop_done() - - def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: - """ - on_cmd receives cmd from ten graph. - current supported cmd: - - name: flush - example: - {"name": "flush"} - """ - logger.info("on_cmd") - cmd_name = cmd.get_name() - - logger.info(f"on_cmd [{cmd_name}]") - - if cmd_name is CMD_IN_FLUSH: - self.outdate_ts = int(time.time() * 1000000) - - # send out - out_cmd = Cmd.create(CMD_OUT_FLUSH) - ten.send_cmd(out_cmd) - - cmd_result = CmdResult.create(StatusCode.OK) - cmd_result.set_property_string("detail", "success") - ten.return_result(cmd_result, cmd) - - def on_data(self, ten: TenEnv, data: Data) -> None: - """ - on_data receives data from ten graph. - current supported data: - - name: text_data - example: - {name: text_data, properties: {text: "hello"} - """ - logger.info("on_data") - - try: - text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT) - except Exception as e: - logger.warning(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}") - return - - if len(text) == 0: - logger.debug("on_data text is empty, ignored") - return - - logger.info(f"OnData input text: [{text}]") - - self.text_queue.put(Message(text, int(time.time() * 1000000))) - - def process_text_queue(self, ten: TenEnv): - logger.info("process_text_queue") - - while True: - msg = self.text_queue.get() - logger.debug(f"process_text_queue, text: [{msg.text}]") - - if msg.received_ts < self.outdate_ts: - logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}") - continue - - start_time = time.time() - buf = self.pcm.new_buf() - first_frame_latency = 0 - n = 0 - pcm_frame_read = 0 - read_bytes = 0 - sent_frames = 0 - - audio_stream = self.elevenlabs_tts.text_to_speech_stream(msg.text) - - for chunk in self.pcm.read_pcm_stream(audio_stream, self.pcm_frame_size): - if msg.received_ts < self.outdate_ts: - logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}") - break - - if not chunk: - logger.info("read pcm stream EOF") - break - - n = len(chunk) - read_bytes += n - pcm_frame_read += n - - if pcm_frame_read != self.pcm.get_pcm_frame_size(): - logger.debug(f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size") - continue - - self.pcm.send(ten, buf) - buf = self.pcm.new_buf() - pcm_frame_read = 0 - sent_frames += 1 - - if first_frame_latency == 0: - first_frame_latency = int((time.time() - start_time) * 1000) - logger.info(f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms") - - logger.debug(f"sending pcm data, text: [{msg.text}]") - - if pcm_frame_read > 0: - self.pcm.send(ten, buf) - sent_frames += 1 - logger.info(f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}") - - finish_latency = int((time.time() - start_time) * 1000) - logger.info(f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, - first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms" - ) diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/extension.py b/agents/ten_packages/extension/elevenlabs_tts_python/extension.py index b634e157..01dd44b3 100644 --- a/agents/ten_packages/extension/elevenlabs_tts_python/extension.py +++ b/agents/ten_packages/extension/elevenlabs_tts_python/extension.py @@ -1 +1,53 @@ -EXTENSION_NAME = "elevenlabs_tts_python" +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import traceback + +from ten_ai_base.helper import PCMWriter +from .elevenlabs_tts import ElevenLabsTTS, ElevenLabsTTSConfig +from ten import ( + AsyncTenEnv, +) +from ten_ai_base.tts import AsyncTTSBaseExtension + +class ElevenLabsTTSExtension(AsyncTTSBaseExtension): + def __init__(self, name: str) -> None: + super().__init__(name) + self.config = None + self.client = None + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + ten_env.log_debug("on_init") + + async def on_start(self, ten_env: AsyncTenEnv) -> None: + try: + await super().on_start(ten_env) + ten_env.log_debug("on_start") + self.config = ElevenLabsTTSConfig.create(ten_env=ten_env) + + if not self.config.api_key: + raise ValueError("api_key is required") + + self.client = ElevenLabsTTS(self.config) + except Exception as err: + ten_env.log_error(f"on_start failed: {traceback.format_exc()}") + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + ten_env.log_debug("on_stop") + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + ten_env.log_debug("on_deinit") + + async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None: + audio_stream = await self.client.text_to_speech_stream(input_text) + + async for audio_data in audio_stream: + self.send_audio_out(ten_env, audio_data) + + async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None: + return await super().on_cancel_tts(ten_env) \ No newline at end of file diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/log.py b/agents/ten_packages/extension/elevenlabs_tts_python/log.py deleted file mode 100644 index fad21710..00000000 --- a/agents/ten_packages/extension/elevenlabs_tts_python/log.py +++ /dev/null @@ -1,12 +0,0 @@ -import logging -from .extension import EXTENSION_NAME - -logger = logging.getLogger(EXTENSION_NAME) -logger.setLevel(logging.INFO) - -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s") - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/manifest.json b/agents/ten_packages/extension/elevenlabs_tts_python/manifest.json index 67414016..2a24022c 100644 --- a/agents/ten_packages/extension/elevenlabs_tts_python/manifest.json +++ b/agents/ten_packages/extension/elevenlabs_tts_python/manifest.json @@ -1,68 +1,79 @@ { - "type": "extension", - "name": "elevenlabs_tts_python", - "version": "0.4.0", - "dependencies": [ - { - "type": "system", - "name": "ten_runtime_python", - "version": "0.4" + "type": "extension", + "name": "elevenlabs_tts_python", + "version": "0.4.2", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.4.2" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md", + "tests/**" + ] + }, + "api": { + "property": { + "api_key": { + "type": "string" + }, + "model_id": { + "type": "string" + }, + "request_timeout_seconds": { + "type": "int64" + }, + "similarity_boost": { + "type": "float64" + }, + "speaker_boost": { + "type": "bool" + }, + "stability": { + "type": "float64" + }, + "style": { + "type": "float64" + }, + "optimize_streaming_latency": { + "type": "int64" + }, + "voice_id": { + "type": "string" + } + }, + "data_in": [ + { + "name": "text_data", + "property": { + "text": { + "type": "string" + } } + } ], - "api": { - "property": { - "api_key": { - "type": "string" - }, - "model_id": { - "type": "string" - }, - "request_timeout_seconds": { - "type": "int64" - }, - "similarity_boost": { - "type": "float64" - }, - "speaker_boost": { - "type": "bool" - }, - "stability": { - "type": "float64" - }, - "style": { - "type": "float64" - }, - "optimize_streaming_latency": { - "type": "int64" - }, - "voice_id": { - "type": "string" - } - }, - "data_in": [ - { - "name": "text_data", - "property": { - "text": { - "type": "string" - } - } - } - ], - "cmd_in": [ - { - "name": "flush" - } - ], - "cmd_out": [ - { - "name": "flush" - } - ], - "audio_frame_out": [ - { - "name": "pcm_frame" - } - ] - } + "cmd_in": [ + { + "name": "flush" + } + ], + "cmd_out": [ + { + "name": "flush" + } + ], + "audio_frame_out": [ + { + "name": "pcm_frame" + } + ] + } } \ No newline at end of file diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/pcm.py b/agents/ten_packages/extension/elevenlabs_tts_python/pcm.py deleted file mode 100644 index 6f0bf493..00000000 --- a/agents/ten_packages/extension/elevenlabs_tts_python/pcm.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by XinHui Li in 2024-07. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# - -import logging -from typing import Iterator -from ten import AudioFrame, TenEnv, AudioFrameDataFmt - - -class Pcm: - def __init__(self, config) -> None: - self.config = config - - def get_pcm_frame(self, buf: memoryview) -> AudioFrame: - frame = AudioFrame.create(self.config.name) - frame.set_bytes_per_sample(self.config.bytes_per_sample) - frame.set_sample_rate(self.config.sample_rate) - frame.set_number_of_channels(self.config.num_channels) - frame.set_timestamp(self.config.timestamp) - frame.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) - frame.set_samples_per_channel(self.config.samples_per_channel // self.config.channel) - - frame.alloc_buf(self.get_pcm_frame_size()) - frame_buf = frame.lock_buf() - # copy data - frame_buf[:] = buf - frame.unlock_buf(frame_buf) - - return frame - - def get_pcm_frame_size(self) -> int: - return (self.config.samples_per_channel * self.config.channel * self.config.bytes_per_sample) - - def new_buf(self) -> bytearray: - return bytearray(self.get_pcm_frame_size()) - - def read_pcm_stream(self, stream: Iterator[bytes], chunk_size: int) -> Iterator[bytes]: - chunk = b"" - for data in stream: - chunk += data - while len(chunk) >= chunk_size: - yield chunk[:chunk_size] - chunk = chunk[chunk_size:] - - if chunk: - yield chunk - - def send(self, ten: TenEnv, buf: memoryview) -> None: - try: - frame = self.get_pcm_frame(buf) - ten.send_audio_frame(frame) - except Exception as e: - logging.error(f"send frame failed, {e}") - - -class PcmConfig: - def __init__(self) -> None: - self.bytes_per_sample = 2 - self.channel = 1 - self.name = "pcm_frame" - self.sample_rate = 16000 - self.samples_per_channel = 16000 // 100 - self.timestamp = 0 diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/requirements.txt b/agents/ten_packages/extension/elevenlabs_tts_python/requirements.txt index 3dff4c02..5e8e39a8 100644 --- a/agents/ten_packages/extension/elevenlabs_tts_python/requirements.txt +++ b/agents/ten_packages/extension/elevenlabs_tts_python/requirements.txt @@ -1 +1 @@ -elevenlabs==1.4.1 +elevenlabs \ No newline at end of file diff --git a/agents/ten_packages/extension/elevenlabs_tts_python/tests/test_basic.py b/agents/ten_packages/extension/elevenlabs_tts_python/tests/test_basic.py new file mode 100644 index 00000000..c3755f44 --- /dev/null +++ b/agents/ten_packages/extension/elevenlabs_tts_python/tests/test_basic.py @@ -0,0 +1,36 @@ +# +# Copyright © 2024 Agora +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0, with certain conditions. +# Refer to the "LICENSE" file in the root directory for more information. +# +from pathlib import Path +from ten import ExtensionTester, TenEnvTester, Cmd, CmdResult, StatusCode + + +class ExtensionTesterBasic(ExtensionTester): + def check_hello(self, ten_env: TenEnvTester, result: CmdResult): + statusCode = result.get_status_code() + print("receive hello_world, status:" + str(statusCode)) + + if statusCode == StatusCode.OK: + ten_env.stop_test() + + def on_start(self, ten_env: TenEnvTester) -> None: + new_cmd = Cmd.create("hello_world") + + print("send hello_world") + ten_env.send_cmd( + new_cmd, + lambda ten_env, result: self.check_hello(ten_env, result), + ) + + print("tester on_start_done") + ten_env.on_start_done() + + +def test_basic(): + tester = ExtensionTesterBasic() + tester.add_addon_base_dir(str(Path(__file__).resolve().parent.parent)) + tester.set_test_mode_single("default_async_extension_python") + tester.run() diff --git a/agents/ten_packages/extension/minimax_tts_python/BUILD.gn b/agents/ten_packages/extension/minimax_tts_python/BUILD.gn new file mode 100644 index 00000000..02a50c6e --- /dev/null +++ b/agents/ten_packages/extension/minimax_tts_python/BUILD.gn @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import("//build/feature/ten_package.gni") + +ten_package("minimax_tts_python") { + package_kind = "extension" + + resources = [ + "__init__.py", + "addon.py", + "extension.py", + "manifest.json", + "property.json", + "tests", + ] +} diff --git a/agents/ten_packages/extension/minimax_tts_python/README.md b/agents/ten_packages/extension/minimax_tts_python/README.md new file mode 100644 index 00000000..013a4631 --- /dev/null +++ b/agents/ten_packages/extension/minimax_tts_python/README.md @@ -0,0 +1,29 @@ +# minimax_tts_python + + + +## Features + + + +- xxx feature + +## API + +Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). + + + +## Development + +### Build + + + +### Unit test + + + +## Misc + + diff --git a/agents/ten_packages/extension/minimax_tts_python/__init__.py b/agents/ten_packages/extension/minimax_tts_python/__init__.py index ada52945..72593ab2 100644 --- a/agents/ten_packages/extension/minimax_tts_python/__init__.py +++ b/agents/ten_packages/extension/minimax_tts_python/__init__.py @@ -1,11 +1,6 @@ # -# -# Agora Real Time Engagement -# Created by Tomas Liu/XinHui Li in 2024. -# Copyright (c) 2024 Agora IO. All rights reserved. -# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. # from . import addon -from .log import logger - -logger.info("minimax_tts_python extension loaded") diff --git a/agents/ten_packages/extension/minimax_tts_python/addon.py b/agents/ten_packages/extension/minimax_tts_python/addon.py index 2274dd49..6bdf4ec5 100644 --- a/agents/ten_packages/extension/minimax_tts_python/addon.py +++ b/agents/ten_packages/extension/minimax_tts_python/addon.py @@ -1,22 +1,19 @@ # -# -# Agora Real Time Engagement -# Created by Tomas Liu/XinHui Li in 2024. -# Copyright (c) 2024 Agora IO. All rights reserved. -# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. # from ten import ( Addon, register_addon_as_extension, TenEnv, ) -from .extension import MinimaxTTSExtension -from .log import logger @register_addon_as_extension("minimax_tts_python") class MinimaxTTSExtensionAddon(Addon): def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: - logger.info("on_create_instance") + from .extension import MinimaxTTSExtension + ten_env.log_info("MinimaxTTSExtensionAddon on_create_instance") ten_env.on_create_instance_done(MinimaxTTSExtension(name), context) diff --git a/agents/ten_packages/extension/minimax_tts_python/extension.py b/agents/ten_packages/extension/minimax_tts_python/extension.py index 6a09cd1a..052ce250 100644 --- a/agents/ten_packages/extension/minimax_tts_python/extension.py +++ b/agents/ten_packages/extension/minimax_tts_python/extension.py @@ -1,293 +1,56 @@ # +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. # -# Agora Real Time Engagement -# Created by Tomas Liu/XinHui Li in 2024. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import threading -from datetime import datetime -import requests -import json - -from queue import Queue -from typing import Iterator - +import traceback +from ten.data import Data +from ten_ai_base.tts import AsyncTTSBaseExtension +from .minimax_tts import MinimaxTTS, MinimaxTTSConfig from ten import ( - AudioFrame, - AudioFrameDataFmt, - VideoFrame, - Extension, - TenEnv, - Cmd, - StatusCode, - CmdResult, - Data, + AsyncTenEnv, ) -from .log import logger -PROPERTY_API_KEY = "api_key" -PROPERTY_GROUP_ID = "group_id" -PROPERTY_MODEL = "model" -PROPERTY_REQUEST_TIMEOUT_SECONDS = "request_timeout_seconds" -PROPERTY_SAMPLE_RATE = "sample_rate" -PROPERTY_URL = "url" -PROPERTY_VOICE_ID = "voice_id" +class MinimaxTTSExtension(AsyncTTSBaseExtension): + def __init__(self, name: str): + super().__init__(name) + self.client = None -class MinimaxTTSExtension(Extension): - ten_env: TenEnv = None + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + ten_env.log_debug("on_init") - api_key: str = "" - dump: bool = False - group_id: str = "" - model: str = "speech-01-turbo" - request_timeout_seconds: int = 10 - sample_rate: int = 32000 - url: str = "https://api.minimax.chat/v1/t2a_v2" - voice_id: str = "male-qn-qingse" + async def on_start(self, ten_env: AsyncTenEnv) -> None: + await super().on_start(ten_env) + ten_env.log_debug("on_start") - thread: threading.Thread = None - queue = Queue() + config = MinimaxTTSConfig.create(ten_env=ten_env) - stopped: bool = False - outdate_ts = datetime.now() - mutex = threading.Lock() + ten_env.log_info(f"config: {config.api_key}, {config.group_id}") - def on_init(self, ten_env: TenEnv) -> None: - logger.info("MinimaxTTSExtension on_init") - self.ten_env = ten_env - ten_env.on_init_done() + if not config.api_key or not config.group_id: + raise ValueError("api_key and group_id are required") - def on_start(self, ten_env: TenEnv) -> None: - logger.info("MinimaxTTSExtension on_start") + self.client = MinimaxTTS(config) - try: - self.api_key = ten_env.get_property_string(PROPERTY_API_KEY) - except Exception as err: - logger.info(f"GetProperty required {PROPERTY_API_KEY} failed, err: {err}") + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + ten_env.log_debug("on_stop") - try: - self.group_id = ten_env.get_property_string(PROPERTY_GROUP_ID) - except Exception as err: - logger.info(f"GetProperty required {PROPERTY_GROUP_ID} failed, err: {err}") - return + # TODO: clean up resources - try: - self.model = ten_env.get_property_string(PROPERTY_MODEL) - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_MODEL} failed, err: {err}") + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + ten_env.log_debug("on_deinit") + async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None: try: - self.sample_rate = ten_env.get_property_int(PROPERTY_SAMPLE_RATE) + data = self.client.get(ten_env, input_text) + async for frame in data: + self.send_audio_out(ten_env, frame, sample_rate=self.client.config.sample_rate) except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_SAMPLE_RATE} failed, err: {err}") - - try: - self.request_timeout_seconds = ten_env.get_property_int(PROPERTY_REQUEST_TIMEOUT_SECONDS) - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_REQUEST_TIMEOUT_SECONDS} failed, err: {err}") - - try: - self.voice_id = ten_env.get_property_string(PROPERTY_VOICE_ID) - except Exception as err: - logger.info(f"GetProperty optional {PROPERTY_VOICE_ID} failed, err: {err}") - - self.thread = threading.Thread(target=self.loop) - self.thread.start() - - ten_env.on_start_done() - - def on_stop(self, ten_env: TenEnv) -> None: - logger.info("MinimaxTTSExtension on_stop") - - self.stopped = True - self._flush() - self.queue.put(None) - if self.thread: - self.thread.join() - self.thread = None - - ten_env.on_stop_done() - - def loop(self) -> None: - while not self.stopped: - entry = self.queue.get() - if entry is None: - return - - try: - ts, text = entry - if self._need_interrupt(ts): - continue - self._call_tts_stream(ts, text) - except Exception as e: - logger.exception(f"Failed to handle entry, err {e}") - - def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: - cmd_name = cmd.get_name() - logger.info("on_cmd name {}".format(cmd_name)) - - if cmd_name == "flush": - self._flush() - - out_cmd = Cmd.create("flush") - ten_env.send_cmd( - out_cmd, lambda ten, result: logger.info( - "send_cmd flush done"), - ) - - cmd_result = CmdResult.create(StatusCode.OK) - ten_env.return_result(cmd_result, cmd) - - def on_data(self, ten_env: TenEnv, data: Data) -> None: - logger.debug("on_data") - - try: - text = data.get_property_string("text") - except Exception as e: - logger.warning(f"on_data get_property_string text error: {e}") - return - - if len(text) == 0: - logger.debug("on_data text is empty, ignored") - return - - logger.info(f"OnData input text: [{text}]") - - self.queue.put((datetime.now(), text)) - - def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: - pass - - def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: - pass - - def _need_interrupt(self, ts: datetime.time) -> bool: - with self.mutex: - return self.outdate_ts > ts - - def _call_tts_stream(self, ts: datetime, text: str) -> Iterator[bytes]: - payload = { - "model": self.model, - "text": text, - "stream": True, - "voice_setting": { - "voice_id": self.voice_id, - "speed": 1.0, - "vol": 1.0, - "pitch": 0 - }, - "pronunciation_dict": { - "tone": [] - }, - "audio_setting": { - "sample_rate": self.sample_rate, - "format": "pcm", - "channel": 1 - } - } - - url = "%s?GroupId=%s" % (self.url, self.group_id) - headers = { - 'accept': 'application/json, text/plain, */*', - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - start_time = datetime.now() - logger.info(f"start request, url: {self.url}, text: {text}") - ttfb = None - try: - with requests.request("POST", url, stream=True, headers=headers, data=json.dumps(payload), timeout=self.request_timeout_seconds) as response: - trace_id = "" - alb_receive_time = "" - - try: - trace_id = response.headers.get("Trace-Id") - except: - logger.warning("get response, no Trace-Id") - try: - alb_receive_time = response.headers.get("alb_receive_time") - except: - logger.warning("get response, no alb_receive_time") - - logger.info(f"get response trace-id: {trace_id}, alb_receive_time: {alb_receive_time}, cost_time {self._duration_in_ms_since(start_time)}ms") - - response.raise_for_status() - - for chunk in (response.raw): - if self._need_interrupt(ts): - logger.warning(f"trace-id: {trace_id}, interrupted") - break - - if not chunk: - continue - if chunk[:5] != b'data:': - logger.debug(f"invalid chunk data {data}") - continue - - logger.debug(f"chunk len {len(chunk)}") - data = json.loads(chunk[5:]) - - if "extra_info" in data: - break - - if "data" not in data: - logger.warning(f"invalid chunk data {data}") - continue - - if "audio" not in data["data"]: - logger.warning(f"invalid chunk data {data}") - continue - - audio = data["data"]['audio'] - if audio is not None and audio != '\n': - decoded_hex = bytes.fromhex(audio) - if len(decoded_hex) > 0: - self._send_audio_out(decoded_hex) - - if not ttfb: - ttfb = self._duration_in_ms_since(start_time) - logger.info(f"trace-id: {trace_id}, ttfb {ttfb}ms") - except Exception as e: - logger.warning(f"unknown err {e}") - finally: - logger.info(f"http loop done, cost_time {self._duration_in_ms_since(start_time)}ms") - - def _send_audio_out(self, audio_data: bytearray) -> None: - self._dump_audio_if_need(audio_data, "out") - - try: - f = AudioFrame.create("pcm_frame") - f.set_sample_rate(self.sample_rate) - f.set_bytes_per_sample(2) - f.set_number_of_channels(1) - f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) - f.set_samples_per_channel(len(audio_data) // 2) - f.alloc_buf(len(audio_data)) - buff = f.lock_buf() - buff[:] = audio_data - f.unlock_buf(buff) - self.ten_env.send_audio_frame(f) - except Exception as e: - logger.exception("error send audio frame, {e}") - - def _flush(self) -> None: - with self.mutex: - self.outdate_ts = datetime.now() - while not self.queue.empty(): - self.queue.get() - - def _dump_audio_if_need(self, buf: bytearray, suffix: str) -> None: - if not self.dump: - return - - with open("{}_{}.pcm".format("minimax_tts", suffix), "ab") as dump_file: - dump_file.write(buf) - - def _duration_in_ms(self, start: datetime, end: datetime) -> int: - return int((end - start).total_seconds() * 1000) + ten_env.log_error(f"on_request_tts failed: {traceback.format_exc()}") - def _duration_in_ms_since(self, start: datetime) -> int: - return self._duration_in_ms(start, datetime.now()) + async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None: + return await super().on_cancel_tts(ten_env) diff --git a/agents/ten_packages/extension/minimax_tts_python/log.py b/agents/ten_packages/extension/minimax_tts_python/log.py deleted file mode 100644 index 2aeb786b..00000000 --- a/agents/ten_packages/extension/minimax_tts_python/log.py +++ /dev/null @@ -1,19 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by Tomas Liu/XinHui Li in 2024. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import logging - -logger = logging.getLogger("minimax_tts_python") -logger.setLevel(logging.INFO) - -formatter_str = ("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s") -formatter = logging.Formatter(formatter_str) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/minimax_tts_python/manifest.json b/agents/ten_packages/extension/minimax_tts_python/manifest.json index c910048f..12913aa9 100644 --- a/agents/ten_packages/extension/minimax_tts_python/manifest.json +++ b/agents/ten_packages/extension/minimax_tts_python/manifest.json @@ -1,62 +1,73 @@ { - "type": "extension", - "name": "minimax_tts_python", - "version": "0.1.0", - "dependencies": [ - { - "type": "system", - "name": "ten_runtime_python", - "version": "0.4" + "type": "extension", + "name": "minimax_tts_python", + "version": "0.4.2", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.4.2" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md", + "tests/**" + ] + }, + "api": { + "property": { + "api_key": { + "type": "string" + }, + "group_id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "request_timeout_seconds": { + "type": "int64" + }, + "sample_rate": { + "type": "int64" + }, + "url": { + "type": "string" + }, + "voice_id": { + "type": "string" + } + }, + "data_in": [ + { + "name": "text_data", + "property": { + "text": { + "type": "string" + } } + } ], - "api": { - "property": { - "api_key": { - "type": "string" - }, - "group_id": { - "type": "string" - }, - "model": { - "type": "string" - }, - "request_timeout_seconds": { - "type": "int64" - }, - "sample_rate": { - "type": "int64" - }, - "url": { - "type": "string" - }, - "voice_id": { - "type": "string" - } - }, - "data_in": [ - { - "name": "text_data", - "property": { - "text": { - "type": "string" - } - } - } - ], - "cmd_in": [ - { - "name": "flush" - } - ], - "cmd_out": [ - { - "name": "flush" - } - ], - "audio_frame_out": [ - { - "name": "pcm_frame" - } - ] - } + "cmd_in": [ + { + "name": "flush" + } + ], + "cmd_out": [ + { + "name": "flush" + } + ], + "audio_frame_out": [ + { + "name": "pcm_frame" + } + ] + } } \ No newline at end of file diff --git a/agents/ten_packages/extension/minimax_tts_python/minimax_tts.py b/agents/ten_packages/extension/minimax_tts_python/minimax_tts.py new file mode 100644 index 00000000..6f0184c9 --- /dev/null +++ b/agents/ten_packages/extension/minimax_tts_python/minimax_tts.py @@ -0,0 +1,117 @@ +import asyncio +from dataclasses import dataclass +import traceback +import aiohttp +import json +from datetime import datetime +from typing import AsyncIterator + +from ten.async_ten_env import AsyncTenEnv +from ten_ai_base.config import BaseConfig + +@dataclass +class MinimaxTTSConfig(BaseConfig): + api_key: str = "" + model: str = "speech-01-turbo" + voice_id: str = "male-qn-qingse" + sample_rate: int = 32000 + url: str = "https://api.minimax.chat/v1/t2a_v2" + group_id: str = "" + request_timeout_seconds: int = 10 + + +class MinimaxTTS: + def __init__(self, config: MinimaxTTSConfig): + self.config = config + + + async def get(self, ten_env: AsyncTenEnv, text: str) -> AsyncIterator[bytes]: + payload = json.dumps({ + "model": self.config.model, + "text": text, + "stream": True, + "voice_setting": { + "voice_id": self.config.voice_id, + "speed": 1.0, + "vol": 1.0, + "pitch": 0 + }, + "pronunciation_dict": { + "tone": [] + }, + "audio_setting": { + "sample_rate": self.config.sample_rate, + "format": "pcm", + "channel": 1 + } + }) + + url = f"{self.config.url}?GroupId={self.config.group_id}" + headers = { + "accept": "application/json, text/plain, */*", + "Authorization": f"Bearer {self.config.api_key}", + "Content-Type": "application/json" + } + + start_time = datetime.now() + ten_env.log_info(f"Start request, url: {self.config.url}, text: {text}") + ttfb = None + + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, headers=headers, data=payload) as response: + trace_id = "" + alb_receive_time = "" + + try: + trace_id = response.headers.get("Trace-Id") + except: + ten_env.log_warn("get response, no Trace-Id") + try: + alb_receive_time = response.headers.get("alb_receive_time") + except: + ten_env.log_warn("get response, no alb_receive_time") + + ten_env.log_info(f"get response trace-id: {trace_id}, alb_receive_time: {alb_receive_time}, cost_time {self._duration_in_ms_since(start_time)}ms") + + if response.status != 200: + raise Exception(f"Request failed with status {response.status}") + + buffer = b"" + async for chunk in response.content.iter_chunked(1024): # Read in 1024 byte chunks + buffer += chunk + + # Split the buffer into lines based on newline character + while b'\n' in buffer: + line, buffer = buffer.split(b'\n', 1) + + # Process only lines that start with "data:" + if line.startswith(b'data:'): + try: + json_data = json.loads(line[5:].decode('utf-8').strip()) + + # Check for the required keys in the JSON data + if "data" in json_data and "extra_info" not in json_data: + audio = json_data["data"].get("audio") + if audio: + decoded_hex = bytes.fromhex(audio) + yield decoded_hex + except (json.JSONDecodeError, UnicodeDecodeError) as e: + # Handle malformed JSON or decoding errors + ten_env.log_warn(f"Error decoding line: {e}") + continue + if not ttfb: + ttfb = self._duration_in_ms_since(start_time) + ten_env.log_info(f"trace-id: {trace_id}, ttfb {ttfb}ms") + except aiohttp.ClientError as e: + ten_env.log_error(f"Client error occurred: {e}") + except asyncio.TimeoutError: + ten_env.log_error("Request timed out") + finally: + ten_env.log_info(f"http loop done, cost_time {self._duration_in_ms_since(start_time)}ms") + + def _duration_in_ms(self, start: datetime, end: datetime) -> int: + return int((end - start).total_seconds() * 1000) + + def _duration_in_ms_since(self, start: datetime) -> int: + return self._duration_in_ms(start, datetime.now()) diff --git a/agents/ten_packages/extension/minimax_tts_python/requirements.txt b/agents/ten_packages/extension/minimax_tts_python/requirements.txt index ef487e06..ce235718 100644 --- a/agents/ten_packages/extension/minimax_tts_python/requirements.txt +++ b/agents/ten_packages/extension/minimax_tts_python/requirements.txt @@ -1 +1 @@ -requests==2.32.3 \ No newline at end of file +aiohttp \ No newline at end of file diff --git a/agents/ten_packages/extension/minimax_tts_python/tests/test_basic.py b/agents/ten_packages/extension/minimax_tts_python/tests/test_basic.py new file mode 100644 index 00000000..c3755f44 --- /dev/null +++ b/agents/ten_packages/extension/minimax_tts_python/tests/test_basic.py @@ -0,0 +1,36 @@ +# +# Copyright © 2024 Agora +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0, with certain conditions. +# Refer to the "LICENSE" file in the root directory for more information. +# +from pathlib import Path +from ten import ExtensionTester, TenEnvTester, Cmd, CmdResult, StatusCode + + +class ExtensionTesterBasic(ExtensionTester): + def check_hello(self, ten_env: TenEnvTester, result: CmdResult): + statusCode = result.get_status_code() + print("receive hello_world, status:" + str(statusCode)) + + if statusCode == StatusCode.OK: + ten_env.stop_test() + + def on_start(self, ten_env: TenEnvTester) -> None: + new_cmd = Cmd.create("hello_world") + + print("send hello_world") + ten_env.send_cmd( + new_cmd, + lambda ten_env, result: self.check_hello(ten_env, result), + ) + + print("tester on_start_done") + ten_env.on_start_done() + + +def test_basic(): + tester = ExtensionTesterBasic() + tester.add_addon_base_dir(str(Path(__file__).resolve().parent.parent)) + tester.set_test_mode_single("default_async_extension_python") + tester.run() diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py index 76b9284d..ee7c45b4 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/config.py @@ -28,6 +28,7 @@ def _init(obj, ten_env: TenEnv): # if not ten_env.is_property_exist(field.name): # continue try: + ten_env.log_info(f"init field.name: {field.name}") match field.type: case builtins.str: val = ten_env.get_property_string(field.name) @@ -45,4 +46,4 @@ def _init(obj, ten_env: TenEnv): case _: pass except Exception as e: - pass + ten_env.log_error(f"Error: {e}") diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/const.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/const.py index e9772b1b..bb19e0e8 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/const.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/const.py @@ -3,7 +3,15 @@ CMD_PROPERTY_TOOL = "tool" CMD_PROPERTY_RESULT = "tool_result" CMD_CHAT_COMPLETION_CALL = "chat_completion_call" +CMD_IN_FLUSH = "flush" +CMD_OUT_FLUSH = "flush" -DATA_OUTPUT_NAME = "text_data" -DATA_OUTPUT_PROPERTY_TEXT = "text" -DATA_OUTPUT_PROPERTY_END_OF_SEGMENT = "end_of_segment" \ No newline at end of file +DATA_OUT_NAME = "text_data" +DATA_OUT_PROPERTY_TEXT = "text" +DATA_OUT_PROPERTY_END_OF_SEGMENT = "end_of_segment" + +DATA_IN_PROPERTY_TEXT = "text" +DATA_IN_PROPERTY_END_OF_SEGMENT = "end_of_segment" + +DATA_INPUT_NAME = "text_data" +AUDIO_FRAME_OUTPUT_NAME = "pcm_frame" \ No newline at end of file diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/helper.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/helper.py index 9bf012d8..4da14707 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/helper.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/helper.py @@ -7,6 +7,8 @@ # import asyncio from collections import deque +from datetime import datetime +import functools from typing import Callable from ten.async_ten_env import AsyncTenEnv @@ -112,4 +114,48 @@ async def flush(self): def __len__(self): """Return the current size of the queue.""" - return len(self._queue) \ No newline at end of file + return len(self._queue) + +def write_pcm_to_file(buffer: bytearray, file_name: str) -> None: + """Helper function to write PCM data to a file.""" + with open(file_name, "ab") as f: # append to file + f.write(buffer) + + +def generate_file_name(prefix: str) -> str: + # Create a timestamp for the file name + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return f"{prefix}_{timestamp}.pcm" + +class PCMWriter: + def __init__(self, prefix: str, write_pcm: bool, buffer_size: int = 1024 * 64): + self.write_pcm = write_pcm + self.buffer = bytearray() + self.buffer_size = buffer_size + self.file_name = generate_file_name(prefix) if write_pcm else None + self.loop = asyncio.get_event_loop() + + async def write(self, data: bytes) -> None: + """Accumulate data into the buffer and write to file when necessary.""" + if not self.write_pcm: + return + + self.buffer.extend(data) + + # Write to file if buffer is full + if len(self.buffer) >= self.buffer_size: + await self._flush() + + async def flush(self) -> None: + """Write any remaining data in the buffer to the file.""" + if self.write_pcm and self.buffer: + await self._flush() + + async def _flush(self) -> None: + """Helper method to write the buffer to the file.""" + if self.file_name: + await self.loop.run_in_executor( + None, + functools.partial(write_pcm_to_file, self.buffer[:], self.file_name), + ) + self.buffer.clear() \ No newline at end of file diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/llm.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/llm.py index 147abe1e..0b74bf62 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/llm.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/llm.py @@ -13,7 +13,7 @@ from ten.async_ten_env import AsyncTenEnv from ten.cmd import Cmd from ten.cmd_result import CmdResult, StatusCode -from .const import CMD_PROPERTY_TOOL, CMD_TOOL_REGISTER, DATA_OUTPUT_NAME, DATA_OUTPUT_PROPERTY_END_OF_SEGMENT, DATA_OUTPUT_PROPERTY_TEXT, CMD_CHAT_COMPLETION_CALL +from .const import CMD_PROPERTY_TOOL, CMD_TOOL_REGISTER, DATA_OUT_NAME, DATA_OUT_PROPERTY_END_OF_SEGMENT, DATA_OUT_PROPERTY_TEXT, CMD_CHAT_COMPLETION_CALL from .types import LLMCallCompletionArgs, LLMDataCompletionArgs, LLMToolMetadata from .helper import AsyncQueue import json @@ -108,11 +108,11 @@ async def flush_input_items(self, ten_env: AsyncTenEnv): def send_text_output(self, ten_env: AsyncTenEnv, sentence: str, end_of_segment: bool): try: - output_data = Data.create(DATA_OUTPUT_NAME) + output_data = Data.create(DATA_OUT_NAME) output_data.set_property_string( - DATA_OUTPUT_PROPERTY_TEXT, sentence) + DATA_OUT_PROPERTY_TEXT, sentence) output_data.set_property_bool( - DATA_OUTPUT_PROPERTY_END_OF_SEGMENT, end_of_segment + DATA_OUT_PROPERTY_END_OF_SEGMENT, end_of_segment ) ten_env.send_data(output_data) ten_env.log_info( diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/tts.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/tts.py new file mode 100644 index 00000000..6eaa7c3d --- /dev/null +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/tts.py @@ -0,0 +1,154 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from abc import ABC, abstractmethod +import asyncio +import traceback + +from ten import ( + AsyncExtension, + Data, +) +from ten.async_ten_env import AsyncTenEnv +from ten.audio_frame import AudioFrame, AudioFrameDataFmt +from ten.cmd import Cmd +from ten.cmd_result import CmdResult, StatusCode +from ten_ai_base.const import CMD_IN_FLUSH, CMD_OUT_FLUSH, DATA_IN_PROPERTY_END_OF_SEGMENT, DATA_IN_PROPERTY_TEXT +from ten_ai_base.types import TTSPcmOptions +from .helper import AsyncQueue, PCMWriter, get_property_bool, get_property_string + + +class AsyncTTSBaseExtension(AsyncExtension, ABC): + """ + Base class for implementing a Text-to-Speech Extension. + This class provides a basic implementation for converting text to speech. + It automatically handles the processing of tts requests. + Use begin_send_audio_out, send_audio_out, end_send_audio_out to send the audio data to the output. + Override on_request_tts to implement the TTS logic. + """ + # Create the queue for message processing + + def __init__(self, name: str): + super().__init__(name) + self.queue = AsyncQueue() + self.current_task = None + self.loop_task = None + self.leftover_bytes = b'' + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + + async def on_start(self, ten_env: AsyncTenEnv) -> None: + await super().on_start(ten_env) + + if self.loop_task is None: + self.loop = asyncio.get_event_loop() + self.loop_task = self.loop.create_task(self._process_queue(ten_env)) + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + self.loop_task.cancel() + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + + async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None: + cmd_name = cmd.get_name() + async_ten_env.log_info(f"on_cmd name: {cmd_name}") + + if cmd_name == CMD_IN_FLUSH: + await self.on_cancel_tts(async_ten_env) + await self.flush_input_items(async_ten_env) + await async_ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH)) + async_ten_env.log_info("on_cmd sent flush") + status_code, detail = StatusCode.OK, "success" + cmd_result = CmdResult.create(status_code) + cmd_result.set_property_string("detail", detail) + async_ten_env.return_result(cmd_result, cmd) + + async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None: + # Get the necessary properties + async_ten_env.log_info(f"on_data name: {data.get_name()}") + input_text = get_property_string(data, DATA_IN_PROPERTY_TEXT) + end_of_segment = get_property_bool(data, DATA_IN_PROPERTY_END_OF_SEGMENT) + + if not input_text: + async_ten_env.log_warn("ignore empty text") + return + + # Start an asynchronous task for handling tts + await self.queue.put([input_text, end_of_segment]) + + async def flush_input_items(self, ten_env: AsyncTenEnv): + """Flushes the self.queue and cancels the current task.""" + # Flush the queue using the new flush method + await self.queue.flush() + + # Cancel the current task if one is running + if self.current_task: + ten_env.log_info("Cancelling the current task during flush.") + self.current_task.cancel() + + def send_audio_out(self, ten_env: AsyncTenEnv, audio_data: bytes, **args: TTSPcmOptions) -> None: + """End sending audio out.""" + sample_rate = args.get("sample_rate", 16000) + bytes_per_sample = args.get("bytes_per_sample", 2) + number_of_channels = args.get("number_of_channels", 1) + try: + # Combine leftover bytes with new audio data + combined_data = self.leftover_bytes + audio_data + + # Check if combined_data length is odd + if len(combined_data) % (bytes_per_sample * number_of_channels) != 0: + # Save the last incomplete frame + valid_length = len(combined_data) - (len(combined_data) % (bytes_per_sample * number_of_channels)) + self.leftover_bytes = combined_data[valid_length:] + combined_data = combined_data[:valid_length] + else: + self.leftover_bytes = b'' + + if combined_data: + f = AudioFrame.create("pcm_frame") + f.set_sample_rate(sample_rate) + f.set_bytes_per_sample(bytes_per_sample) + f.set_number_of_channels(number_of_channels) + f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) + f.set_samples_per_channel(len(combined_data) // (bytes_per_sample * number_of_channels)) + f.alloc_buf(len(combined_data)) + buff = f.lock_buf() + buff[:] = combined_data + f.unlock_buf(buff) + ten_env.send_audio_frame(f) + except Exception as e: + ten_env.log_error(f"error send audio frame, {traceback.format_exc()}") + + @abstractmethod + async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None: + """ + Called when a new input item is available in the queue. Override this method to implement the TTS request logic. + Use send_audio_out to send the audio data to the output when the audio data is ready. + """ + pass + + @abstractmethod + async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None: + """Called when the TTS request is cancelled.""" + pass + + + async def _process_queue(self, ten_env: AsyncTenEnv): + """Asynchronously process queue items one by one.""" + while True: + # Wait for an item to be available in the queue + [text, end_of_segment] = await self.queue.get() + + try: + self.current_task = asyncio.create_task( + self.on_request_tts(ten_env, text, end_of_segment)) + await self.current_task # Wait for the current task to finish or be cancelled + except asyncio.CancelledError: + ten_env.log_info(f"Task cancelled: {text}") + except Exception as err: + ten_env.log_error(f"Task failed: {text}, err: {traceback.format_exc()}") diff --git a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/types.py b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/types.py index 04fb7f50..ae2dae77 100644 --- a/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/types.py +++ b/agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/types.py @@ -93,4 +93,15 @@ class LLMCallCompletionArgs(TypedDict, total=False): class LLMDataCompletionArgs(TypedDict, total=False): messages: Iterable[LLMChatCompletionMessageParam] - no_tool: bool = False \ No newline at end of file + no_tool: bool = False + + +class TTSPcmOptions(TypedDict, total=False): + sample_rate: int + """The sample rate of the audio data in Hz.""" + + num_channels: int + """The number of audio channels.""" + + bytes_per_sample: int + """The number of bytes per sample.""" \ No newline at end of file