Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use message_collector to replace chat_transcriber #248

Merged
merged 1 commit into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 40 additions & 200 deletions agents/property.json

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions agents/ten_packages/extension/message_collector/BUILD.gn
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2022-11.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import("//build/feature/ten_package.gni")

# Declares the "message_collector" package through the `ten_package` template
# imported above from //build/feature/ten_package.gni.
ten_package("message_collector") {
# This package is an extension (as opposed to other package kinds the
# template may accept — see ten_package.gni for the allowed values).
package_kind = "extension"

# Files belonging to the package. Presumably these are what the template
# copies into the build output — TODO confirm against ten_package.gni.
resources = [
"__init__.py",
"manifest.json",
"property.json",
"src/__init__.py",
"src/addon.py",
"src/extension.py",
"src/log.py",
]
}
29 changes: 29 additions & 0 deletions agents/ten_packages/extension/message_collector/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# message_collector

<!-- brief introduction for the extension -->

## Features

<!-- main features introduction -->

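- Receives `text_data` messages, accumulates final text per `stream_id` until `end_of_segment`, and emits each update as a JSON payload in a `data` message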

## API

Refer to the `api` definition in [manifest.json](manifest.json) and the default values in [property.json](property.json).

<!-- Additional API.md can be referred to if extra introduction needed -->

## Development

### Build

<!-- build dependencies and steps -->

### Unit test

<!-- how to do unit test for the extension -->

## Misc

<!-- others if applicable -->
11 changes: 11 additions & 0 deletions agents/ten_packages/extension/message_collector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
# `addon` is imported for its side effect only: loading src/addon.py applies
# the `@register_addon_as_extension("message_collector")` decorator to the
# addon class (presumably registering it with the runtime — confirm in `ten`).
from .src import addon
from .src.log import logger

# Runs at import time of this package.
logger.info("message_collector extension loaded")
51 changes: 51 additions & 0 deletions agents/ten_packages/extension/message_collector/manifest.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"type": "extension",
"name": "message_collector",
"version": "0.1.0",
"dependencies": [
{
"type": "system",
"name": "ten_runtime_python",
"version": "0.1.0"
}
],
"package": {
"include": [
"manifest.json",
"property.json",
"BUILD.gn",
"**.tent",
"**.py",
"src/**.tent",
"src/**.py",
"README.md"
]
},
"api": {
"property": {},
"data_in": [
{
"name": "text_data",
"property": {
"text": {
"type": "string"
},
"is_final": {
"type": "bool"
},
"stream_id": {
"type": "uint32"
},
"end_of_segment": {
"type": "bool"
}
}
}
],
"data_out": [
{
"name": "data"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Empty file.
22 changes: 22 additions & 0 deletions agents/ten_packages/extension/message_collector/src/addon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from ten import (
Addon,
register_addon_as_extension,
TenEnv,
)
from .extension import MessageCollectorExtension
from .log import logger


@register_addon_as_extension("message_collector")
class MessageCollectorExtensionAddon(Addon):
    """Addon, decorated under the name "message_collector", that creates
    MessageCollectorExtension instances."""

    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
        """Build a MessageCollectorExtension called `name` and pass it, with
        the caller-supplied `context`, to `ten_env.on_create_instance_done`."""
        logger.info("MessageCollectorExtensionAddon on_create_instance")
        extension = MessageCollectorExtension(name)
        ten_env.on_create_instance_done(extension, context)
150 changes: 150 additions & 0 deletions agents/ten_packages/extension/message_collector/src/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import json
import time
from ten import (
AudioFrame,
VideoFrame,
Extension,
TenEnv,
Cmd,
StatusCode,
CmdResult,
Data,
)
from .log import logger



# Name of the "flush" command. Not referenced anywhere else in the visible
# code of this module.
CMD_NAME_FLUSH = "flush"

# Property names read from each incoming data message in
# MessageCollectorExtension.on_data().
TEXT_DATA_TEXT_FIELD = "text"
TEXT_DATA_FINAL_FIELD = "is_final"
TEXT_DATA_STREAM_ID_FIELD = "stream_id"
TEXT_DATA_END_OF_SEGMENT_FIELD = "end_of_segment"

# Text accumulated so far for each stream id (key: stream id, value: text).
# on_data() extends an entry when it receives final text and deletes the entry
# when it receives end_of_segment.
# NOTE(review): this dict lives at module level, so every
# MessageCollectorExtension instance in the process shares it — confirm that
# stream ids cannot collide across instances.
cached_text_map = {}


class MessageCollectorExtension(Extension):
    """Collect incoming text messages and forward them as JSON `data` messages.

    Each message handled by `on_data` is read for its `text`, `is_final`,
    `stream_id` and `end_of_segment` properties. Final text is accumulated per
    stream id in the module-level `cached_text_map` until a message with
    `end_of_segment` arrives; every message (final or not) results in one
    outgoing `data` message carrying a JSON-encoded payload.
    """

    def on_init(self, ten_env: TenEnv) -> None:
        """Log the lifecycle event and acknowledge initialization."""
        logger.info("MessageCollectorExtension on_init")
        ten_env.on_init_done()

    def on_start(self, ten_env: TenEnv) -> None:
        """Log the lifecycle event and acknowledge start."""
        logger.info("MessageCollectorExtension on_start")

        # TODO: read properties, initialize resources

        ten_env.on_start_done()

    def on_stop(self, ten_env: TenEnv) -> None:
        """Log the lifecycle event and acknowledge stop."""
        logger.info("MessageCollectorExtension on_stop")

        # TODO: clean up resources

        ten_env.on_stop_done()

    def on_deinit(self, ten_env: TenEnv) -> None:
        """Log the lifecycle event and acknowledge de-initialization."""
        logger.info("MessageCollectorExtension on_deinit")
        ten_env.on_deinit_done()

    def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
        """Log the command name and reply with an OK result.

        No command is actually processed yet; every command gets
        `StatusCode.OK`.
        """
        cmd_name = cmd.get_name()
        logger.info("on_cmd name %s", cmd_name)

        # TODO: process cmd

        cmd_result = CmdResult.create(StatusCode.OK)
        ten_env.return_result(cmd_result, cmd)

    @staticmethod
    def _read_property(data: Data, getter_name: str, field: str, default):
        """Return `data.<getter_name>(field)`, or `default` if that call raises.

        A failure is logged with its traceback and otherwise ignored, so one
        missing or mistyped property does not prevent the message from being
        forwarded.
        """
        try:
            return getattr(data, getter_name)(field)
        except Exception as err:
            logger.exception(
                "on_data %s %s error: %s", getter_name, field, err
            )
            return default

    def on_data(self, ten_env: TenEnv, data: Data) -> None:
        """Handle one incoming data message and forward it as JSON.

        Supported data:
          - name: text_data
            example:
            {"name": "text_data", "property": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}}

        Properties that cannot be read fall back to: text "", is_final True,
        stream_id 0, end_of_segment False.

        The outgoing message is named "data" and stores, in its "data" buffer
        property, the UTF-8 bytes of a JSON object with the keys `text`,
        `is_final` (set from `end_of_segment`), `stream_id`, `data_type`
        ("transcribe") and `text_ts` (current time in milliseconds).
        """
        logger.info("on_data")

        text = self._read_property(
            data, "get_property_string", TEXT_DATA_TEXT_FIELD, ""
        )
        final = self._read_property(
            data, "get_property_bool", TEXT_DATA_FINAL_FIELD, True
        )
        stream_id = self._read_property(
            data, "get_property_int", TEXT_DATA_STREAM_ID_FIELD, 0
        )
        end_of_segment = self._read_property(
            data, "get_property_bool", TEXT_DATA_END_OF_SEGMENT_FIELD, False
        )

        # Lazy %-args: the message is only rendered if DEBUG is enabled.
        logger.debug(
            "on_data %s: %s %s: %s %s: %s %s: %s",
            TEXT_DATA_TEXT_FIELD,
            text,
            TEXT_DATA_FINAL_FIELD,
            final,
            TEXT_DATA_STREAM_ID_FIELD,
            stream_id,
            TEXT_DATA_END_OF_SEGMENT_FIELD,
            end_of_segment,
        )

        # Per-stream accumulation:
        #   - end of segment: prepend whatever was cached for this stream and
        #     drop the cache entry;
        #   - final text inside a segment: prepend the cached text and store
        #     the combined text back in the cache;
        #   - non-final text: forwarded as received, cache untouched.
        if end_of_segment:
            text = cached_text_map.pop(stream_id, "") + text
        elif final:
            text = cached_text_map.get(stream_id, "") + text
            cached_text_map[stream_id] = text

        msg_data = json.dumps({
            "text": text,
            "is_final": end_of_segment,
            "stream_id": stream_id,
            "data_type": "transcribe",
            "text_ts": int(time.time() * 1000),  # Convert to milliseconds
        })

        try:
            # Wrap the JSON payload as bytes in a "data" message and send it
            # to the graph.
            ten_data = Data.create("data")
            ten_data.set_property_buf("data", msg_data.encode())
            ten_env.send_data(ten_data)
        except Exception as e:
            # Best effort: a failed send is logged and the message is dropped.
            logger.warning(f"on_data new_data error: {e}")

    def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
        """Ignore audio frames (not handled by this extension yet)."""
        # TODO: process pcm frame
        pass

    def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
        """Ignore video frames (not handled by this extension yet)."""
        # TODO: process image frame
        pass
22 changes: 22 additions & 0 deletions agents/ten_packages/extension/message_collector/src/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import logging

# Logger shared by the modules of the message_collector extension.
logger = logging.getLogger("message_collector")
logger.setLevel(logging.INFO)

# Record layout: timestamp, logger name, level, process id, source location,
# message — joined with " - ".
formatter_str = " - ".join(
    (
        "%(asctime)s",
        "%(name)s",
        "%(levelname)s",
        "%(process)d",
        "[%(filename)s:%(lineno)d]",
        "%(message)s",
    )
)
formatter = logging.Formatter(formatter_str)

# Stream handler with default arguments, using the layout above.
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
Loading