Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add transcribe_asr extension; optimize bedrock_llm extension; fix polly_tts bugs #174

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,9 @@ def converse_stream_worker(start_time, input_text, memory):
first_sentence_sent = False

for event in stream:
if start_time < self.outdate_ts:
logger.info(
f"GetConverseStream recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}"
)
# allow 100ms buffer time, in case interruptor's flush cmd comes just after on_data event
if (start_time + 100_000) < self.outdate_ts:
logger.info(f"GetConverseStream recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}, delta > 100ms")
break

if "contentBlockDelta" in event:
Expand All @@ -278,8 +277,8 @@ def converse_stream_worker(start_time, input_text, memory):
sentence, content, sentence_is_final = parse_sentence(
sentence, content
)
if len(sentence) == 0 or not sentence_is_final:
logger.info(f"sentence {sentence} is empty or not final")
if not sentence or not sentence_is_final:
logger.info(f"sentence [{sentence}] is empty or not final")
break
logger.info(
f"GetConverseStream recv for input text: [{input_text}] got sentence: [{sentence}]"
Expand Down Expand Up @@ -313,7 +312,10 @@ def converse_stream_worker(start_time, input_text, memory):

if len(full_content.strip()):
# remember response as assistant content in memory
memory.append(
if memory and memory[-1]['role'] == 'assistant':
memory[-1]['content'].append({"text": full_content})
else:
memory.append(
{"role": "assistant", "content": [{"text": full_content}]}
)
else:
Expand Down
4 changes: 2 additions & 2 deletions agents/addon/extension/polly_tts/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"type": "string"
},
"sample_rate": {
"type": "int64"
"type": "string"
},
"lang_code": {
"type": "string"
Expand Down Expand Up @@ -60,4 +60,4 @@
}
]
}
}
}
6 changes: 0 additions & 6 deletions agents/addon/extension/polly_tts/polly_tts_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ def __init__(self, name: str):
self.bytes_per_sample = 2
self.number_of_channels = 1

def on_init(
self, rte: RteEnv, manifest: MetadataInfo, property: MetadataInfo
) -> None:
logger.info("PollyTTSExtension on_init")
rte.on_init_done(manifest, property)

def on_start(self, rte: RteEnv) -> None:
logger.info("PollyTTSExtension on_start")

Expand Down
11 changes: 11 additions & 0 deletions agents/addon/extension/transcribe_asr_python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
## Amazon Transcribe ASR Extension

### Configurations

You can configure this extension by providing the following environment variables:

| Env | Required | Default | Notes |
| -- | -- | -- | -- |
| AWS_REGION | No | us-east-1 | The Region of Amazon Transcribe service you want to use. |
| AWS_ACCESS_KEY_ID | No | - | Access Key of your IAM User, make sure you've set proper permissions to [start stream transcription](https://docs.aws.amazon.com/transcribe/latest/APIReference/API_streaming_StartStreamTranscription.html). Will use default credentials provider if not provided. Check [document](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). |
| AWS_SECRET_ACCESS_KEY | No | - | Secret Key of your IAM User. Will use default credentials provider if not provided. Check [document](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). |
5 changes: 5 additions & 0 deletions agents/addon/extension/transcribe_asr_python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import transcribe_asr_addon
from .extension import EXTENSION_NAME
from .log import logger

# Emit a single line at import time so that loading of the package shows up in the logs.
logger.info("%s extension loaded", EXTENSION_NAME)
1 change: 1 addition & 0 deletions agents/addon/extension/transcribe_asr_python/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Shared identifier of this extension: it is passed to register_addon_as_extension()
# when the addon is registered and is also used as the name of the package logger.
EXTENSION_NAME = "transcribe_asr"
14 changes: 14 additions & 0 deletions agents/addon/extension/transcribe_asr_python/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging
from .extension import EXTENSION_NAME

# Package-wide logger: INFO level, written to the console through a stream
# handler whose records carry timestamp, logger name, level, process id and
# the source location of the call.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s"
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

logger = logging.getLogger(EXTENSION_NAME)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
76 changes: 76 additions & 0 deletions agents/addon/extension/transcribe_asr_python/manifest.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{
"type": "extension",
"name": "transcribe_asr_python",
"version": "0.1.0",
"language": "python",
"dependencies": [
{
"type": "system",
"name": "rte_runtime_python",
"version": "0.4.0"
}
],
"api": {
"property": {
"region": {
"type": "string"
},
"access_key": {
"type": "string"
},
"secret_key": {
"type": "string"
},
"sample_rate": {
"type": "string"
},
"lang_code": {
"type": "string"
}
},
"pcm_frame_in": [
{
"name": "pcm_frame"
}
],
"cmd_in": [
{
"name": "on_user_joined"
},
{
"name": "on_user_left"
},
{
"name": "on_connection_failure"
}
],
"data_out": [
{
"name": "text_data",
"property": {
"time": {
"type": "int64"
},
"duration_ms": {
"type": "int64"
},
"language": {
"type": "string"
},
"text": {
"type": "string"
},
"is_final": {
"type": "bool"
},
"stream_id": {
"type": "uint32"
},
"end_of_segment": {
"type": "bool"
}
}
}
]
}
}
1 change: 1 addition & 0 deletions agents/addon/extension/transcribe_asr_python/property.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
amazon-transcribe==0.6.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from rte import (
Addon,
register_addon_as_extension,
RteEnv,
)
from .extension import EXTENSION_NAME
from .log import logger
from .transcribe_asr_extension import TranscribeAsrExtension


@register_addon_as_extension(EXTENSION_NAME)
class TranscribeAsrExtensionAddon(Addon):
    """Addon registered under EXTENSION_NAME that creates TranscribeAsrExtension instances."""

    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
        """Build a new extension named ``addon_name`` and hand it back to the runtime.

        Args:
            rte: Runtime environment used to report the created instance.
            addon_name: Name given to the new extension instance.
            context: Opaque value passed through unchanged to
                ``on_create_instance_done``.
        """
        logger.info("on_create_instance")
        extension = TranscribeAsrExtension(addon_name)
        rte.on_create_instance_done(extension, context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from rte import (
Extension,
RteEnv,
Cmd,
PcmFrame,
StatusCode,
CmdResult,
)

import asyncio
import threading

from .log import logger
from .transcribe_wrapper import AsyncTranscribeWrapper, TranscribeConfig

# Names of the optional properties read in on_start. Each name is also the
# TranscribeConfig attribute that the property value overrides.
PROPERTY_REGION = "region"  # Optional
PROPERTY_ACCESS_KEY = "access_key"  # Optional
PROPERTY_SECRET_KEY = "secret_key"  # Optional
PROPERTY_SAMPLE_RATE = "sample_rate"  # Optional
PROPERTY_LANG_CODE = "lang_code"  # Optional


class TranscribeAsrExtension(Extension):
    """Extension that forwards incoming PCM frames to an AsyncTranscribeWrapper.

    Frames received in ``on_pcm_frame`` are pushed onto an asyncio queue that
    is shared with the wrapper. The wrapper's ``run`` method is executed on a
    separate thread started in ``on_start`` and is given the queue, the
    runtime environment and this extension's event loop.
    """

    def __init__(self, name: str):
        """Create the event loop and the frame queue; nothing is started yet.

        Args:
            name: Extension instance name, passed to the base class.
        """
        super().__init__(name)

        self.stopped = False
        self.transcribe = None
        self.thread = None

        # Create and install the loop BEFORE building the queue: on
        # Python < 3.10 asyncio.Queue binds to the current event loop at
        # construction time, so the queue must be created once self.loop is
        # the current loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.queue = asyncio.Queue(maxsize=3000)  # about 3000 * 10ms = 30s input

    def on_start(self, rte: RteEnv) -> None:
        """Build the config from optional properties and start the wrapper thread.

        Every property is optional: a missing, empty or unparsable value is
        logged at debug level and the default from
        ``TranscribeConfig.default_config()`` is kept.
        """
        logger.info("TranscribeAsrExtension on_start")

        transcribe_config = TranscribeConfig.default_config()

        optional_params = (
            PROPERTY_REGION,
            PROPERTY_SAMPLE_RATE,
            PROPERTY_LANG_CODE,
            PROPERTY_ACCESS_KEY,
            PROPERTY_SECRET_KEY,
        )
        for optional_param in optional_params:
            try:
                value = rte.get_property_string(optional_param).strip()
                if not value:
                    continue
                if optional_param == PROPERTY_SAMPLE_RATE:
                    # TranscribeConfig stores sample_rate as an int; keep that
                    # type when the value comes from a string property.
                    value = int(value)
                setattr(transcribe_config, optional_param, value)
            except Exception as err:
                logger.debug(f"GetProperty optional {optional_param} failed, err: {err}. Using default value: {getattr(transcribe_config, optional_param)}")

        self.transcribe = AsyncTranscribeWrapper(transcribe_config, self.queue, rte, self.loop)

        logger.info("Starting async_transcribe_wrapper thread")
        self.thread = threading.Thread(target=self.transcribe.run, args=[])
        self.thread.start()

        rte.on_start_done()

    def put_pcm_frame(self, pcm_frame: PcmFrame) -> None:
        """Schedule ``queue.put(pcm_frame)`` on the loop and wait up to 100 ms.

        Errors are logged and swallowed so that a failing put never
        propagates to the caller.

        Args:
            pcm_frame: Frame to enqueue; ``on_stop`` passes ``None`` as the
                stop marker.
        """
        try:
            asyncio.run_coroutine_threadsafe(self.queue.put(pcm_frame), self.loop).result(timeout=0.1)
        except asyncio.QueueFull:
            # NOTE: Queue.put() waits for a free slot instead of raising
            # QueueFull, so a full queue shows up as a timeout from result()
            # and is handled by the generic branch below.
            logger.exception("Queue is full, dropping frame")
        except Exception as e:
            logger.exception(f"Error putting frame in queue: {e}")

    def on_pcm_frame(self, rte: RteEnv, pcm_frame: PcmFrame) -> None:
        """Enqueue every incoming PCM frame."""
        self.put_pcm_frame(pcm_frame=pcm_frame)

    def on_stop(self, rte: RteEnv) -> None:
        """Signal the wrapper to stop, wait for its thread and close the loop."""
        logger.info("TranscribeAsrExtension on_stop")

        # put an empty frame to stop transcribe_wrapper
        self.put_pcm_frame(None)
        self.stopped = True
        # The thread only exists once on_start has run.
        if self.thread is not None:
            self.thread.join()
        self.loop.stop()
        self.loop.close()

        rte.on_stop_done()

    def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
        """Log the incoming command and always answer it with an OK result."""
        logger.info("TranscribeAsrExtension on_cmd")
        cmd_json = cmd.to_json()
        logger.info("TranscribeAsrExtension on_cmd json: " + cmd_json)

        cmd_name = cmd.get_name()
        logger.info("got cmd %s", cmd_name)

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "success")
        rte.return_result(cmd_result, cmd)
29 changes: 29 additions & 0 deletions agents/addon/extension/transcribe_asr_python/transcribe_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Union

class TranscribeConfig:
    """Connection and audio-format settings used by the Transcribe ASR extension.

    All settings are plain instance attributes, so callers may overwrite them
    after construction.
    """

    def __init__(self,
                 region: str,
                 access_key: str,
                 secret_key: str,
                 sample_rate: Union[str, int],
                 lang_code: str):
        """Store the supplied settings and the fixed audio-format values.

        Args:
            region: Service region name, e.g. ``"us-east-1"``.
            access_key: Access key; the default config uses an empty string.
            secret_key: Secret key; the default config uses an empty string.
            sample_rate: Sample rate as an int or a numeric string; it is
                stored as an int.
            lang_code: Language code, e.g. ``'en-US'``.

        Raises:
            ValueError: If ``sample_rate`` cannot be converted to an int.
        """
        self.region = region
        self.access_key = access_key
        self.secret_key = secret_key

        self.lang_code = lang_code
        self.sample_rate = int(sample_rate)

        self.media_encoding = 'pcm'
        # Must be a plain int: a trailing comma here would turn the value
        # into the one-element tuple (2,).
        self.bytes_per_sample = 2
        self.channel_nums = 1

    @classmethod
    def default_config(cls):
        """Return a config with the defaults: us-east-1, empty keys, 16000 Hz, en-US."""
        return cls(
            region="us-east-1",
            access_key="",
            secret_key="",
            sample_rate=16000,
            lang_code='en-US'
        )
Loading