Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/refactor polly tts #449

Merged
merged 5 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agents/examples/experimental/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,7 @@
"secret_key": "${env:AWS_SECRET_ACCESS_KEY}",
"engine": "generative",
"voice": "Ruth",
"sample_rate": "16000",
"sample_rate": 16000,
"lang_code": "en-US"
}
},
Expand Down Expand Up @@ -1439,7 +1439,7 @@
"secret_key": "${env:AWS_SECRET_ACCESS_KEY}",
"engine": "generative",
"voice": "Ruth",
"sample_rate": "16000",
"sample_rate": 16000,
"lang_code": "en-US"
}
},
Expand Down
19 changes: 19 additions & 0 deletions agents/ten_packages/extension/polly_tts/BUILD.gn
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import("//build/feature/ten_package.gni")

# Packages the polly_tts extension. Every Python module that the package
# imports at runtime must be listed in `resources`; extension.py imports from
# polly_tts.py, so that file is listed alongside the other sources.
ten_package("polly_tts") {
  package_kind = "extension"

  resources = [
    "__init__.py",
    "addon.py",
    "extension.py",
    "manifest.json",
    "polly_tts.py",
    "property.json",
    "tests",
  ]
}
11 changes: 6 additions & 5 deletions agents/ten_packages/extension/polly_tts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import polly_tts_addon
from .extension import EXTENSION_NAME
from .log import logger

logger.info(f"{EXTENSION_NAME} extension loaded")
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from . import addon
17 changes: 17 additions & 0 deletions agents/ten_packages/extension/polly_tts/addon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from ten import (
Addon,
register_addon_as_extension,
TenEnv,
)

@register_addon_as_extension("polly_tts")
class PollyTTSExtensionAddon(Addon):
    """Addon registered as "polly_tts" that instantiates the TTS extension."""

    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
        """Create a ``PollyTTSExtension`` called *name* and hand it back.

        The new instance is reported through
        ``ten_env.on_create_instance_done`` together with the caller-supplied
        *context*.
        """
        # The extension module is imported at call time, not at module import.
        from .extension import PollyTTSExtension

        ten_env.log_info("polly tts on_create_instance")
        extension = PollyTTSExtension(name)
        ten_env.on_create_instance_done(extension, context)
58 changes: 57 additions & 1 deletion agents/ten_packages/extension/polly_tts/extension.py
Original file line number Diff line number Diff line change
@@ -1 +1,57 @@
EXTENSION_NAME = "polly_tts"
from ten_ai_base.tts import AsyncTTSBaseExtension
from .polly_tts import PollyTTS, PollyTTSConfig
import traceback
from ten import (
AsyncTenEnv,
)
# Names of the configuration properties for this extension. All of them are
# marked optional by the original author.
# NOTE(review): none of these constants is referenced by the code visible in
# this module — confirm whether they are still needed.
PROPERTY_REGION = "region" # Optional
PROPERTY_ACCESS_KEY = "access_key" # Optional
PROPERTY_SECRET_KEY = "secret_key" # Optional
PROPERTY_ENGINE = "engine" # Optional
PROPERTY_VOICE = "voice" # Optional
PROPERTY_SAMPLE_RATE = "sample_rate" # Optional
PROPERTY_LANG_CODE = "lang_code" # Optional

class PollyTTSExtension(AsyncTTSBaseExtension):
    """TTS extension that turns requested text into audio via ``PollyTTS``.

    ``on_start`` loads the configuration and builds the client;
    ``on_request_tts`` streams the synthesized frames out through
    ``send_audio_out``.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Both are populated by on_start and remain None if start-up fails.
        self.client = None
        self.config = None

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        await super().on_init(ten_env)
        ten_env.log_debug("on_init")

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Load the configuration and create the Polly client.

        Failures are logged rather than raised, so ``self.client`` may still
        be ``None`` after this method returns.
        """
        try:
            await super().on_start(ten_env)
            ten_env.log_debug("on_start")
            self.config = PollyTTSConfig.create(ten_env=ten_env)

            if not self.config.access_key or not self.config.secret_key:
                raise ValueError("access_key and secret_key are required")

            self.client = PollyTTS(self.config, ten_env)
        except Exception:
            ten_env.log_error(f"on_start failed: {traceback.format_exc()}")

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        await super().on_stop(ten_env)
        ten_env.log_debug("on_stop")

        # TODO: clean up resources

    async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
        await super().on_deinit(ten_env)
        ten_env.log_debug("on_deinit")

    async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None:
        """Synthesize *input_text* and emit each audio frame as it arrives.

        Errors are logged and swallowed so a single failed request does not
        propagate to the caller.
        """
        client = self.client
        if client is None:
            # on_start swallows its own errors, so the client can be missing
            # here; report that directly instead of an AttributeError trace.
            ten_env.log_error(
                "on_request_tts failed: polly client is not initialized"
            )
            return

        try:
            async for frame in client.text_to_speech_stream(ten_env, input_text):
                self.send_audio_out(
                    ten_env, frame, sample_rate=client.config.sample_rate
                )
        except Exception:
            ten_env.log_error(f"on_request_tts failed: {traceback.format_exc()}")

    async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None:
        return await super().on_cancel_tts(ten_env)
12 changes: 0 additions & 12 deletions agents/ten_packages/extension/polly_tts/log.py

This file was deleted.

13 changes: 12 additions & 1 deletion agents/ten_packages/extension/polly_tts/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
"version": "0.4"
}
],
"package": {
"include": [
"manifest.json",
"property.json",
"BUILD.gn",
"**.tent",
"**.py",
"README.md",
"tests/**"
]
},
"api": {
"property": {
"region": {
Expand All @@ -27,7 +38,7 @@
"type": "string"
},
"sample_rate": {
"type": "string"
"type": "int64"
},
"lang_code": {
"type": "string"
Expand Down
94 changes: 94 additions & 0 deletions agents/ten_packages/extension/polly_tts/polly_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from dataclasses import dataclass
import traceback
import json
from typing import AsyncIterator
from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig
import boto3
from botocore.exceptions import ClientError
from contextlib import closing

@dataclass
class PollyTTSConfig(BaseConfig):
    """Configuration values for the Polly TTS client, with their defaults.

    Field order is part of the generated ``__init__`` signature and must not
    be rearranged.
    """

    region: str = "us-east-1"
    access_key: str = ""
    secret_key: str = ""
    engine: str = "generative"
    # Voice list: https://docs.aws.amazon.com/polly/latest/dg/available-voices.html
    voice: str = "Matthew"
    sample_rate: int = 16000
    lang_code: str = "en-US"
    bytes_per_sample: int = 2
    include_visemes: bool = False
    number_of_channels: int = 1
    audio_format: str = "pcm"

class PollyTTS:
    """Wrapper around a boto3 Polly client that streams synthesized audio."""

    def __init__(self, config: PollyTTSConfig, ten_env: AsyncTenEnv) -> None:
        """
        :param config: A PollyTTSConfig holding region, credentials and voice
                       settings.
        :param ten_env: Environment object used for logging.
        """
        ten_env.log_info("startinit polly tts")
        self.config = config
        if config.access_key and config.secret_key:
            self.client = boto3.client(
                service_name="polly",
                region_name=config.region,
                aws_access_key_id=config.access_key,
                aws_secret_access_key=config.secret_key,
            )
        else:
            # No explicit keys: the client is created without credentials
            # arguments and boto3 resolves them itself.
            self.client = boto3.client(
                service_name="polly", region_name=config.region
            )

        self.voice_metadata = None
        # Number of bytes in 1/100 of a second of audio.
        self.frame_size = int(
            int(config.sample_rate)
            * self.config.number_of_channels
            * self.config.bytes_per_sample
            / 100
        )

    def _synthesize(self, text, ten_env: AsyncTenEnv):
        """
        Synthesizes speech or speech marks from text, using the configured voice.

        :param text: The text to synthesize.
        :param ten_env: Environment object used for logging.
        :return: The audio stream that contains the synthesized speech and a list
                 of visemes that are associated with the speech audio (``None``
                 when ``include_visemes`` is off).
        :raises ClientError: Re-raised after logging when Polly rejects a request.
        """
        try:
            kwargs = {
                "Engine": self.config.engine,
                "OutputFormat": self.config.audio_format,
                "Text": text,
                "VoiceId": self.config.voice,
            }
            if self.config.lang_code is not None:
                kwargs["LanguageCode"] = self.config.lang_code

            # Request the configured rate explicitly; otherwise Polly falls
            # back to its own default while callers label the frames with
            # config.sample_rate. The API expects the rate as a string.
            audio_kwargs = dict(kwargs, SampleRate=str(self.config.sample_rate))
            response = self.client.synthesize_speech(**audio_kwargs)
            audio_stream = response["AudioStream"]

            visemes = None
            if self.config.include_visemes:
                # Second request for speech marks only, built from the base
                # arguments (without SampleRate) as in the original request.
                marks_kwargs = dict(
                    kwargs, OutputFormat="json", SpeechMarkTypes=["viseme"]
                )
                response = self.client.synthesize_speech(**marks_kwargs)
                # One JSON document per line.
                visemes = [
                    json.loads(line)
                    for line in response["AudioStream"].read().decode().splitlines()
                    if line
                ]
                ten_env.log_debug(f"Got {len(visemes)} visemes.")
        except ClientError:
            ten_env.log_error("Couldn't get audio stream.")
            raise
        else:
            return audio_stream, visemes

    async def text_to_speech_stream(self, ten_env: AsyncTenEnv, text: str) -> AsyncIterator[bytes]:
        """Yield synthesized audio for *text* in ``frame_size``-byte chunks.

        Empty input yields nothing. Any failure is logged and ends the stream
        instead of propagating.

        NOTE(review): the boto3 calls below are made directly from this
        coroutine with no ``await``; if they block, so does the event loop —
        confirm whether that is acceptable here.
        """
        if not text:
            ten_env.log_warning("async_polly_handler: empty input detected.")
            # Nothing to synthesize, so skip the request entirely.
            return

        try:
            audio_stream, _visemes = self._synthesize(text, ten_env)
            with closing(audio_stream) as stream:
                for chunk in stream.iter_chunks(chunk_size=self.frame_size):
                    yield chunk
        except Exception:
            ten_env.log_error(traceback.format_exc())
15 changes: 0 additions & 15 deletions agents/ten_packages/extension/polly_tts/polly_tts_addon.py

This file was deleted.

Loading