diff --git a/agents/examples/experimental/property.json b/agents/examples/experimental/property.json index 39dd74b6..1dd89ef1 100644 --- a/agents/examples/experimental/property.json +++ b/agents/examples/experimental/property.json @@ -1273,7 +1273,7 @@ "secret_key": "${env:AWS_SECRET_ACCESS_KEY}", "engine": "generative", "voice": "Ruth", - "sample_rate": "16000", + "sample_rate": 16000, "lang_code": "en-US" } }, @@ -1439,7 +1439,7 @@ "secret_key": "${env:AWS_SECRET_ACCESS_KEY}", "engine": "generative", "voice": "Ruth", - "sample_rate": "16000", + "sample_rate": 16000, "lang_code": "en-US" } }, diff --git a/agents/ten_packages/extension/polly_tts/BUILD.gn b/agents/ten_packages/extension/polly_tts/BUILD.gn new file mode 100644 index 00000000..16a3f9f3 --- /dev/null +++ b/agents/ten_packages/extension/polly_tts/BUILD.gn @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import("//build/feature/ten_package.gni") + +ten_package("polly_tts") { + package_kind = "extension" + + resources = [ + "__init__.py", + "addon.py", + "extension.py", + "manifest.json", + "property.json", + "tests", + ] +} \ No newline at end of file diff --git a/agents/ten_packages/extension/polly_tts/__init__.py b/agents/ten_packages/extension/polly_tts/__init__.py index 50d3bf2c..9e61a472 100644 --- a/agents/ten_packages/extension/polly_tts/__init__.py +++ b/agents/ten_packages/extension/polly_tts/__init__.py @@ -1,5 +1,6 @@ -from . import polly_tts_addon -from .extension import EXTENSION_NAME -from .log import logger - -logger.info(f"{EXTENSION_NAME} extension loaded") +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . import addon \ No newline at end of file diff --git a/agents/ten_packages/extension/polly_tts/addon.py b/agents/ten_packages/extension/polly_tts/addon.py new file mode 100644 index 00000000..1ba59748 --- /dev/null +++ b/agents/ten_packages/extension/polly_tts/addon.py @@ -0,0 +1,17 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from ten import ( + Addon, + register_addon_as_extension, + TenEnv, +) + +@register_addon_as_extension("polly_tts") +class PollyTTSExtensionAddon(Addon): + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import PollyTTSExtension + ten_env.log_info("polly tts on_create_instance") + ten_env.on_create_instance_done(PollyTTSExtension(name), context) diff --git a/agents/ten_packages/extension/polly_tts/extension.py b/agents/ten_packages/extension/polly_tts/extension.py index 63cb3dab..52b68872 100644 --- a/agents/ten_packages/extension/polly_tts/extension.py +++ b/agents/ten_packages/extension/polly_tts/extension.py @@ -1 +1,57 @@ -EXTENSION_NAME = "polly_tts" +from ten_ai_base.tts import AsyncTTSBaseExtension +from .polly_tts import PollyTTS, PollyTTSConfig +import traceback +from ten import ( + AsyncTenEnv, +) +PROPERTY_REGION = "region" # Optional +PROPERTY_ACCESS_KEY = "access_key" # Optional +PROPERTY_SECRET_KEY = "secret_key" # Optional +PROPERTY_ENGINE = "engine" # Optional +PROPERTY_VOICE = "voice" # Optional +PROPERTY_SAMPLE_RATE = "sample_rate" # Optional +PROPERTY_LANG_CODE = "lang_code" # Optional + +class PollyTTSExtension(AsyncTTSBaseExtension): + def __init__(self, name: str): + super().__init__(name) + self.client = None + self.config = None + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + ten_env.log_debug("on_init") + + async def on_start(self, ten_env: AsyncTenEnv) -> None: + try: + await super().on_start(ten_env) + ten_env.log_debug("on_start") + self.config = PollyTTSConfig.create(ten_env=ten_env) + + if not self.config.access_key or not self.config.secret_key: + raise ValueError("access_key and secret_key are required") + + self.client = PollyTTS(self.config, ten_env) + except Exception as err: + ten_env.log_error(f"on_start failed: {traceback.format_exc()}") + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + await super().on_stop(ten_env) + ten_env.log_debug("on_stop") + + # TODO: clean up resources + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + ten_env.log_debug("on_deinit") + + async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None: + try: + data = self.client.text_to_speech_stream(ten_env, input_text) + async for frame in data: + self.send_audio_out(ten_env, frame, sample_rate=self.client.config.sample_rate) + except Exception as err: + ten_env.log_error(f"on_request_tts failed: {traceback.format_exc()}") + + async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None: + return await super().on_cancel_tts(ten_env) \ No newline at end of file diff --git a/agents/ten_packages/extension/polly_tts/log.py b/agents/ten_packages/extension/polly_tts/log.py deleted file mode 100644 index fad21710..00000000 --- a/agents/ten_packages/extension/polly_tts/log.py +++ /dev/null @@ -1,12 +0,0 @@ -import logging -from .extension import EXTENSION_NAME - -logger = logging.getLogger(EXTENSION_NAME) -logger.setLevel(logging.INFO) - -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s") - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler) diff --git a/agents/ten_packages/extension/polly_tts/manifest.json b/agents/ten_packages/extension/polly_tts/manifest.json index 0be14b5d..30c3d478 100644 --- a/agents/ten_packages/extension/polly_tts/manifest.json +++ b/agents/ten_packages/extension/polly_tts/manifest.json @@ -9,6 +9,17 @@ "version": "0.4" } ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md", + "tests/**" + ] + }, "api": { "property": { "region": { @@ -27,7 +38,7 @@ "type": "string" }, "sample_rate": { - "type": "string" + "type": "int64" }, "lang_code": { "type": "string" diff --git a/agents/ten_packages/extension/polly_tts/polly_tts.py b/agents/ten_packages/extension/polly_tts/polly_tts.py new file mode 100644 index 00000000..760e3983 --- /dev/null +++ b/agents/ten_packages/extension/polly_tts/polly_tts.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +import traceback +import json +from typing import AsyncIterator +from ten.async_ten_env import AsyncTenEnv +from ten_ai_base.config import BaseConfig +import boto3 +from botocore.exceptions import ClientError +from contextlib import closing + +@dataclass +class PollyTTSConfig(BaseConfig): + region: str = "us-east-1" + access_key: str = "" + secret_key: str = "" + engine: str = "generative" + voice: str = "Matthew" # https://docs.aws.amazon.com/polly/latest/dg/available-voices.html + sample_rate: int = 16000 + lang_code: str = 'en-US' + bytes_per_sample: int = 2 + include_visemes: bool = False + number_of_channels: int = 1 + audio_format: str = 'pcm' + +class PollyTTS: + def __init__(self, config: PollyTTSConfig, ten_env: AsyncTenEnv) -> None: + """ + :param config: A PollyConfig + """ + ten_env.log_info("startinit polly tts") + self.config = config + if config.access_key and config.secret_key: + self.client = boto3.client(service_name='polly', + region_name=config.region, + aws_access_key_id=config.access_key, + aws_secret_access_key=config.secret_key) + else: + self.client = boto3.client(service_name='polly', region_name=config.region) + + self.voice_metadata = None + self.frame_size = int( + int(config.sample_rate) + * self.config.number_of_channels + * self.config.bytes_per_sample + / 100 + ) + + def _synthesize(self, text, ten_env: AsyncTenEnv): + """ + Synthesizes speech or speech marks from text, using the specified voice. + + :param text: The text to synthesize. + :return: The audio stream that contains the synthesized speech and a list + of visemes that are associated with the speech audio. + """ + try: + kwargs = { + "Engine": self.config.engine, + "OutputFormat": self.config.audio_format, + "Text": text, + "VoiceId": self.config.voice, + } + if self.config.lang_code is not None: + kwargs["LanguageCode"] = self.config.lang_code + response = self.client.synthesize_speech(**kwargs) + audio_stream = response["AudioStream"] + visemes = None + if self.config.include_visemes: + kwargs["OutputFormat"] = "json" + kwargs["SpeechMarkTypes"] = ["viseme"] + response = self.client.synthesize_speech(**kwargs) + visemes = [ + json.loads(v) + for v in response["AudioStream"].read().decode().split() + if v + ] + ten_env.log_debug("Got %s visemes.", len(visemes)) + except ClientError: + ten_env.log_error("Couldn't get audio stream.") + raise + else: + return audio_stream, visemes + + async def text_to_speech_stream(self, ten_env: AsyncTenEnv, text: str) -> AsyncIterator[bytes]: + inputText = text + if len(inputText) == 0: + ten_env.log_warning("async_polly_handler: empty input detected.") + try: + audio_stream, visemes = self._synthesize(inputText, ten_env) + with closing(audio_stream) as stream: + for chunk in stream.iter_chunks(chunk_size=self.frame_size): + yield chunk + except Exception as e: + ten_env.log_error(traceback.format_exc()) \ No newline at end of file diff --git a/agents/ten_packages/extension/polly_tts/polly_tts_addon.py b/agents/ten_packages/extension/polly_tts/polly_tts_addon.py deleted file mode 100644 index 6dc74b9b..00000000 --- a/agents/ten_packages/extension/polly_tts/polly_tts_addon.py +++ /dev/null @@ -1,15 +0,0 @@ -from ten import ( - Addon, - register_addon_as_extension, - TenEnv, -) -from .extension import EXTENSION_NAME - - -@register_addon_as_extension(EXTENSION_NAME) -class PollyTTSExtensionAddon(Addon): - def on_create_instance(self, ten: TenEnv, addon_name: str, context) -> None: - from .log import logger - from .polly_tts_extension import PollyTTSExtension - logger.info("on_create_instance") - ten.on_create_instance_done(PollyTTSExtension(addon_name), context) diff --git a/agents/ten_packages/extension/polly_tts/polly_tts_extension.py b/agents/ten_packages/extension/polly_tts/polly_tts_extension.py deleted file mode 100644 index b41ad5fa..00000000 --- a/agents/ten_packages/extension/polly_tts/polly_tts_extension.py +++ /dev/null @@ -1,170 +0,0 @@ -from ten import ( - Extension, - TenEnv, - Cmd, - AudioFrameDataFmt, - AudioFrame, - Data, - StatusCode, - CmdResult, -) - -import queue -import threading -from datetime import datetime -import traceback -from contextlib import closing - -from .log import logger -from .polly_wrapper import PollyWrapper, PollyConfig - -PROPERTY_REGION = "region" # Optional -PROPERTY_ACCESS_KEY = "access_key" # Optional -PROPERTY_SECRET_KEY = "secret_key" # Optional -PROPERTY_ENGINE = "engine" # Optional -PROPERTY_VOICE = "voice" # Optional -PROPERTY_SAMPLE_RATE = "sample_rate" # Optional -PROPERTY_LANG_CODE = "lang_code" # Optional - - -class PollyTTSExtension(Extension): - def __init__(self, name: str): - super().__init__(name) - - self.outdateTs = datetime.now() - self.stopped = False - self.thread = None - self.queue = queue.Queue() - self.frame_size = None - - self.bytes_per_sample = 2 - self.number_of_channels = 1 - - def on_start(self, ten: TenEnv) -> None: - logger.info("PollyTTSExtension on_start") - - polly_config = PollyConfig.default_config() - - for optional_param in [ - PROPERTY_REGION, - PROPERTY_ENGINE, - PROPERTY_VOICE, - PROPERTY_SAMPLE_RATE, - PROPERTY_LANG_CODE, - PROPERTY_ACCESS_KEY, - PROPERTY_SECRET_KEY, - ]: - try: - value = ten.get_property_string(optional_param).strip() - if value: - polly_config.__setattr__(optional_param, value) - except Exception as err: - logger.debug( - f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {polly_config.__getattribute__(optional_param)}" - ) - - self.polly = PollyWrapper(polly_config) - self.frame_size = int( - int(polly_config.sample_rate) - * self.number_of_channels - * self.bytes_per_sample - / 100 - ) - - self.thread = threading.Thread(target=self.async_polly_handler, args=[ten]) - self.thread.start() - ten.on_start_done() - - def on_stop(self, ten: TenEnv) -> None: - logger.info("PollyTTSExtension on_stop") - - self.stopped = True - self.queue.put(None) - self.flush() - self.thread.join() - ten.on_stop_done() - - def need_interrupt(self, ts: datetime.time) -> bool: - return (self.outdateTs - ts).total_seconds() > 1 - - def __get_frame(self, data: bytes) -> AudioFrame: - sample_rate = int(self.polly.config.sample_rate) - - f = AudioFrame.create("pcm_frame") - f.set_sample_rate(sample_rate) - f.set_bytes_per_sample(2) - f.set_number_of_channels(1) - - f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE) - f.set_samples_per_channel(sample_rate // 100) - f.alloc_buf(self.frame_size) - buff = f.lock_buf() - if len(data) < self.frame_size: - buff[:] = bytes(self.frame_size) # fill with 0 - buff[: len(data)] = data - f.unlock_buf(buff) - return f - - def async_polly_handler(self, ten: TenEnv): - while not self.stopped: - value = self.queue.get() - if value is None: - logger.warning("async_polly_handler: exit due to None value got.") - break - inputText, ts = value - if len(inputText) == 0: - logger.warning("async_polly_handler: empty input detected.") - continue - try: - audio_stream, visemes = self.polly.synthesize(inputText) - with closing(audio_stream) as stream: - for chunk in stream.iter_chunks(chunk_size=self.frame_size): - if self.need_interrupt(ts): - logger.debug( - "async_polly_handler: got interrupt cmd, stop sending pcm frame." - ) - break - - f = self.__get_frame(chunk) - ten.send_audio_frame(f) - except Exception as e: - logger.exception(e) - logger.exception(traceback.format_exc()) - - def flush(self): - logger.info("PollyTTSExtension flush") - while not self.queue.empty(): - self.queue.get() - self.queue.put(("", datetime.now())) - - def on_data(self, ten: TenEnv, data: Data) -> None: - logger.info("PollyTTSExtension on_data") - inputText = data.get_property_string("text") - if len(inputText) == 0: - logger.info("ignore empty text") - return - - is_end = data.get_property_bool("end_of_segment") - - logger.info("on data %s %d", inputText, is_end) - self.queue.put((inputText, datetime.now())) - - def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: - logger.info("PollyTTSExtension on_cmd") - cmd_json = cmd.to_json() - logger.info("PollyTTSExtension on_cmd json: %s" + cmd_json) - - cmdName = cmd.get_name() - if cmdName == "flush": - self.outdateTs = datetime.now() - self.flush() - cmd_out = Cmd.create("flush") - ten.send_cmd( - cmd_out, lambda ten, result: print("PollyTTSExtension send_cmd done") - ) - else: - logger.info("unknown cmd %s", cmdName) - - cmd_result = CmdResult.create(StatusCode.OK) - cmd_result.set_property_string("detail", "success") - ten.return_result(cmd_result, cmd) diff --git a/agents/ten_packages/extension/polly_tts/polly_wrapper.py b/agents/ten_packages/extension/polly_tts/polly_wrapper.py deleted file mode 100644 index 0f1759bf..00000000 --- a/agents/ten_packages/extension/polly_tts/polly_wrapper.py +++ /dev/null @@ -1,175 +0,0 @@ -import io -import json -import logging -import boto3 -from typing import Union -from botocore.exceptions import ClientError - -from .log import logger - -# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/polly/client/synthesize_speech.html -class PollyConfig: - def __init__(self, - region: str, - access_key: str, - secret_key: str, - voice: str, - engine: str, # 'standard'|'neural'|'long-form'|'generative' - sample_rate: Union[str, int], - lang_code: None): # only necessary if using a bilingual voice - self.region = region - self.access_key = access_key - self.secret_key = secret_key - - self.voice = voice - self.engine = engine - self.lang_code = lang_code - self.sample_rate = str(sample_rate) - - self.speech_mark_type = 'sentence' # 'sentence'|'ssml'|'viseme'|'word' - self.audio_format = 'pcm' # 'json'|'mp3'|'ogg_vorbis'|'pcm' - self.include_visemes = False - - @classmethod - def default_config(cls): - return cls( - region="us-east-1", - access_key="", - secret_key="", - engine="generative", - voice="Matthew", # https://docs.aws.amazon.com/polly/latest/dg/available-voices.html - sample_rate=16000, - lang_code='en-US' - ) - - -class PollyWrapper: - """Encapsulates Amazon Polly functions.""" - - def __init__(self, config: PollyConfig): - """ - :param config: A PollyConfig - """ - - self.config = config - - if config.access_key and config.secret_key: - logger.info(f"PollyTTS initialized with access key: {config.access_key}") - - self.client = boto3.client(service_name='polly', - region_name=config.region, - aws_access_key_id=config.access_key, - aws_secret_access_key=config.secret_key) - else: - logger.info(f"PollyTTS initialized without access key, using default credentials provider chain.") - self.client = boto3.client(service_name='polly', region_name=config.region) - - self.voice_metadata = None - - - def describe_voices(self): - """ - Gets metadata about available voices. - - :return: The list of voice metadata. - """ - try: - response = self.client.describe_voices() - self.voice_metadata = response["Voices"] - logger.info("Got metadata about %s voices.", len(self.voice_metadata)) - except ClientError: - logger.exception("Couldn't get voice metadata.") - raise - else: - return self.voice_metadata - - - def synthesize(self, text): - """ - Synthesizes speech or speech marks from text, using the specified voice. - - :param text: The text to synthesize. - :return: The audio stream that contains the synthesized speech and a list - of visemes that are associated with the speech audio. - """ - try: - kwargs = { - "Engine": self.config.engine, - "OutputFormat": self.config.audio_format, - "Text": text, - "VoiceId": self.config.voice, - } - if self.config.lang_code is not None: - kwargs["LanguageCode"] = self.config.lang_code - response = self.client.synthesize_speech(**kwargs) - audio_stream = response["AudioStream"] - logger.info("Got audio stream spoken by %s.", self.config.voice) - visemes = None - if self.config.include_visemes: - kwargs["OutputFormat"] = "json" - kwargs["SpeechMarkTypes"] = ["viseme"] - response = self.client.synthesize_speech(**kwargs) - visemes = [ - json.loads(v) - for v in response["AudioStream"].read().decode().split() - if v - ] - logger.info("Got %s visemes.", len(visemes)) - except ClientError: - logger.exception("Couldn't get audio stream.") - raise - else: - return audio_stream, visemes - - def get_voice_engines(self): - """ - Extracts the set of available voice engine types from the full list of - voice metadata. - - :return: The set of voice engine types. - """ - if self.voice_metadata is None: - self.describe_voices() - - engines = set() - for voice in self.voice_metadata: - for engine in voice["SupportedEngines"]: - engines.add(engine) - return engines - - - def get_languages(self, engine): - """ - Extracts the set of available languages for the specified engine from the - full list of voice metadata. - - :param engine: The engine type to filter on. - :return: The set of languages available for the specified engine type. - """ - if self.voice_metadata is None: - self.describe_voices() - - return { - vo["LanguageName"]: vo["LanguageCode"] - for vo in self.voice_metadata - if engine in vo["SupportedEngines"] - } - - - def get_voices(self, engine, language_code): - """ - Extracts the set of voices that are available for the specified engine type - and language from the full list of voice metadata. - - :param engine: The engine type to filter on. - :param language_code: The language to filter on. - :return: The set of voices available for the specified engine type and language. - """ - if self.voice_metadata is None: - self.describe_voices() - - return { - vo["Name"]: vo["Id"] - for vo in self.voice_metadata - if engine in vo["SupportedEngines"] and language_code == vo["LanguageCode"] - } \ No newline at end of file diff --git a/agents/ten_packages/extension/polly_tts/requirements.txt b/agents/ten_packages/extension/polly_tts/requirements.txt index 0a92c2be..6179f113 100644 --- a/agents/ten_packages/extension/polly_tts/requirements.txt +++ b/agents/ten_packages/extension/polly_tts/requirements.txt @@ -1 +1 @@ -boto3==1.34.143 \ No newline at end of file +boto3>=1.26.0 \ No newline at end of file