-
Notifications
You must be signed in to change notification settings - Fork 366
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: use message_collector to replace chat_transcriber (#248)
- add a new message_collector - replace chat_transcriber and take out cmd_conversion - UI adapt
- Loading branch information
Showing
11 changed files
with
389 additions
and
236 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# | ||
# | ||
# Agora Real Time Engagement | ||
# Created by Wei Hu in 2022-11. | ||
# Copyright (c) 2024 Agora IO. All rights reserved. | ||
# | ||
# | ||
import("//build/feature/ten_package.gni") | ||
|
||
ten_package("message_collector") { | ||
package_kind = "extension" | ||
|
||
resources = [ | ||
"__init__.py", | ||
"manifest.json", | ||
"property.json", | ||
"src/__init__.py", | ||
"src/addon.py", | ||
"src/extension.py", | ||
"src/log.py", | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# message_collector | ||
|
||
<!-- brief introduction for the extension --> | ||
|
||
## Features | ||
|
||
<!-- main features introduction --> | ||
|
||
- xxx feature | ||
|
||
## API | ||
|
||
Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). | ||
|
||
<!-- Additional API.md can be referred to if extra introduction needed --> | ||
|
||
## Development | ||
|
||
### Build | ||
|
||
<!-- build dependencies and steps --> | ||
|
||
### Unit test | ||
|
||
<!-- how to do unit test for the extension --> | ||
|
||
## Misc | ||
|
||
<!-- others if applicable --> |
11 changes: 11 additions & 0 deletions
11
agents/ten_packages/extension/message_collector/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# | ||
# | ||
# Agora Real Time Engagement | ||
# Created by Wei Hu in 2024-08. | ||
# Copyright (c) 2024 Agora IO. All rights reserved. | ||
# | ||
# | ||
from .src import addon | ||
from .src.log import logger | ||
|
||
logger.info("message_collector extension loaded") |
51 changes: 51 additions & 0 deletions
51
agents/ten_packages/extension/message_collector/manifest.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
{ | ||
"type": "extension", | ||
"name": "message_collector", | ||
"version": "0.1.0", | ||
"dependencies": [ | ||
{ | ||
"type": "system", | ||
"name": "ten_runtime_python", | ||
"version": "0.1.0" | ||
} | ||
], | ||
"package": { | ||
"include": [ | ||
"manifest.json", | ||
"property.json", | ||
"BUILD.gn", | ||
"**.tent", | ||
"**.py", | ||
"src/**.tent", | ||
"src/**.py", | ||
"README.md" | ||
] | ||
}, | ||
"api": { | ||
"property": {}, | ||
"data_in": [ | ||
{ | ||
"name": "text_data", | ||
"property": { | ||
"text": { | ||
"type": "string" | ||
}, | ||
"is_final": { | ||
"type": "bool" | ||
}, | ||
"stream_id": { | ||
"type": "uint32" | ||
}, | ||
"end_of_segment": { | ||
"type": "bool" | ||
} | ||
} | ||
} | ||
], | ||
"data_out": [ | ||
{ | ||
"name": "data" | ||
} | ||
] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} |
Empty file.
22 changes: 22 additions & 0 deletions
22
agents/ten_packages/extension/message_collector/src/addon.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# | ||
# | ||
# Agora Real Time Engagement | ||
# Created by Wei Hu in 2024-08. | ||
# Copyright (c) 2024 Agora IO. All rights reserved. | ||
# | ||
# | ||
from ten import ( | ||
Addon, | ||
register_addon_as_extension, | ||
TenEnv, | ||
) | ||
from .extension import MessageCollectorExtension | ||
from .log import logger | ||
|
||
|
||
@register_addon_as_extension("message_collector") | ||
class MessageCollectorExtensionAddon(Addon): | ||
|
||
def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: | ||
logger.info("MessageCollectorExtensionAddon on_create_instance") | ||
ten_env.on_create_instance_done(MessageCollectorExtension(name), context) |
150 changes: 150 additions & 0 deletions
150
agents/ten_packages/extension/message_collector/src/extension.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# | ||
# | ||
# Agora Real Time Engagement | ||
# Created by Wei Hu in 2024-08. | ||
# Copyright (c) 2024 Agora IO. All rights reserved. | ||
# | ||
# | ||
import json | ||
import time | ||
from ten import ( | ||
AudioFrame, | ||
VideoFrame, | ||
Extension, | ||
TenEnv, | ||
Cmd, | ||
StatusCode, | ||
CmdResult, | ||
Data, | ||
) | ||
from .log import logger | ||
|
||
|
||
|
||
CMD_NAME_FLUSH = "flush" | ||
|
||
TEXT_DATA_TEXT_FIELD = "text" | ||
TEXT_DATA_FINAL_FIELD = "is_final" | ||
TEXT_DATA_STREAM_ID_FIELD = "stream_id" | ||
TEXT_DATA_END_OF_SEGMENT_FIELD = "end_of_segment" | ||
|
||
# record the cached text data for each stream id | ||
cached_text_map = {} | ||
|
||
|
||
class MessageCollectorExtension(Extension): | ||
def on_init(self, ten_env: TenEnv) -> None: | ||
logger.info("MessageCollectorExtension on_init") | ||
ten_env.on_init_done() | ||
|
||
def on_start(self, ten_env: TenEnv) -> None: | ||
logger.info("MessageCollectorExtension on_start") | ||
|
||
# TODO: read properties, initialize resources | ||
|
||
ten_env.on_start_done() | ||
|
||
def on_stop(self, ten_env: TenEnv) -> None: | ||
logger.info("MessageCollectorExtension on_stop") | ||
|
||
# TODO: clean up resources | ||
|
||
ten_env.on_stop_done() | ||
|
||
def on_deinit(self, ten_env: TenEnv) -> None: | ||
logger.info("MessageCollectorExtension on_deinit") | ||
ten_env.on_deinit_done() | ||
|
||
def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: | ||
cmd_name = cmd.get_name() | ||
logger.info("on_cmd name {}".format(cmd_name)) | ||
|
||
# TODO: process cmd | ||
|
||
cmd_result = CmdResult.create(StatusCode.OK) | ||
ten_env.return_result(cmd_result, cmd) | ||
|
||
def on_data(self, ten_env: TenEnv, data: Data) -> None: | ||
""" | ||
on_data receives data from ten graph. | ||
current suppotend data: | ||
- name: text_data | ||
example: | ||
{"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}} | ||
""" | ||
logger.info(f"on_data") | ||
|
||
text = "" | ||
final = True | ||
stream_id = 0 | ||
end_of_segment = False | ||
|
||
try: | ||
text = data.get_property_string(TEXT_DATA_TEXT_FIELD) | ||
except Exception as e: | ||
logger.exception( | ||
f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}" | ||
) | ||
|
||
try: | ||
final = data.get_property_bool(TEXT_DATA_FINAL_FIELD) | ||
except Exception as e: | ||
logger.exception( | ||
f"on_data get_property_bool {TEXT_DATA_FINAL_FIELD} error: {e}" | ||
) | ||
|
||
try: | ||
stream_id = data.get_property_int(TEXT_DATA_STREAM_ID_FIELD) | ||
except Exception as e: | ||
logger.exception( | ||
f"on_data get_property_int {TEXT_DATA_STREAM_ID_FIELD} error: {e}" | ||
) | ||
|
||
try: | ||
end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD) | ||
except Exception as e: | ||
logger.exception( | ||
f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}" | ||
) | ||
|
||
logger.debug( | ||
f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}" | ||
) | ||
|
||
# We cache all final text data and append the non-final text data to the cached data | ||
# until the end of the segment. | ||
if end_of_segment: | ||
if stream_id in cached_text_map: | ||
text = cached_text_map[stream_id] + text | ||
del cached_text_map[stream_id] | ||
else: | ||
if final: | ||
if stream_id in cached_text_map: | ||
text = cached_text_map[stream_id] + text | ||
|
||
cached_text_map[stream_id] = text | ||
|
||
msg_data = json.dumps({ | ||
"text": text, | ||
"is_final": end_of_segment, | ||
"stream_id": stream_id, | ||
"data_type": "transcribe", | ||
"text_ts": int(time.time() * 1000), # Convert to milliseconds | ||
}) | ||
|
||
try: | ||
# convert the origin text data to the protobuf data and send it to the graph. | ||
ten_data = Data.create("data") | ||
ten_data.set_property_buf("data", msg_data.encode()) | ||
ten_env.send_data(ten_data) | ||
except Exception as e: | ||
logger.warning(f"on_data new_data error: {e}") | ||
return | ||
|
||
def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: | ||
# TODO: process pcm frame | ||
pass | ||
|
||
def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: | ||
# TODO: process image frame | ||
pass |
22 changes: 22 additions & 0 deletions
22
agents/ten_packages/extension/message_collector/src/log.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# | ||
# | ||
# Agora Real Time Engagement | ||
# Created by Wei Hu in 2024-08. | ||
# Copyright (c) 2024 Agora IO. All rights reserved. | ||
# | ||
# | ||
import logging | ||
|
||
logger = logging.getLogger("message_collector") | ||
logger.setLevel(logging.INFO) | ||
|
||
formatter_str = ( | ||
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - " | ||
"[%(filename)s:%(lineno)d] - %(message)s" | ||
) | ||
formatter = logging.Formatter(formatter_str) | ||
|
||
console_handler = logging.StreamHandler() | ||
console_handler.setFormatter(formatter) | ||
|
||
logger.addHandler(console_handler) |
Oops, something went wrong.