feat: leverage ten's log for interrupt_detector(go) and message_collector(python) #417

Merged: 1 commit, Nov 16, 2024
agents/ten_packages/extension/interrupt_detector/extension.go (15 changes: 4 additions & 11 deletions)

@@ -12,7 +12,6 @@ package extension
 
 import (
 	"fmt"
-	"log/slog"
 
 	"ten_framework/ten"
 )
@@ -24,10 +23,6 @@ const (
 	cmdNameFlush = "flush"
 )
 
-var (
-	logTag = slog.String("extension", "INTERRUPT_DETECTOR_EXTENSION")
-)
-
 type interruptDetectorExtension struct {
 	ten.DefaultExtension
 }
@@ -47,29 +42,27 @@ func (p *interruptDetectorExtension) OnData(
 ) {
 	text, err := data.GetPropertyString(textDataTextField)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataTextField, err), logTag)
+		tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataTextField, err))
 		return
 	}
 
 	final, err := data.GetPropertyBool(textDataFinalField)
 	if err != nil {
-		slog.Warn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataFinalField, err), logTag)
+		tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataFinalField, err))
 		return
 	}
 
-	slog.Debug(fmt.Sprintf("OnData %s: %s %s: %t", textDataTextField, text, textDataFinalField, final), logTag)
+	tenEnv.LogDebug(fmt.Sprintf("OnData %s: %s %s: %t", textDataTextField, text, textDataFinalField, final))
 
 	if final || len(text) >= 2 {
 		flushCmd, _ := ten.NewCmd(cmdNameFlush)
 		tenEnv.SendCmd(flushCmd, nil)
 
-		slog.Info(fmt.Sprintf("sent cmd: %s", cmdNameFlush), logTag)
+		tenEnv.LogInfo(fmt.Sprintf("sent cmd: %s", cmdNameFlush))
	}
 }
 
 func init() {
-	slog.Info("interrupt_detector extension init", logTag)
-
 	// Register addon
 	ten.RegisterAddonAsExtension(
 		"interrupt_detector",
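Taken together, the Go changes drop the package-level slog logger and its logTag attribute in favor of the log methods on the TenEnv handle the runtime passes into each callback. A minimal sketch of the resulting pattern, using only calls visible in this diff (the extension name, field name, and OnData signature are assumptions for illustration):

package extension

import (
	"fmt"

	"ten_framework/ten"
)

// exampleExtension is hypothetical; it mirrors interruptDetectorExtension.
type exampleExtension struct {
	ten.DefaultExtension
}

func (e *exampleExtension) OnData(tenEnv ten.TenEnv, data ten.Data) {
	text, err := data.GetPropertyString("text")
	if err != nil {
		// The runtime attributes each log line to this extension,
		// which is what the old logTag attribute did manually.
		tenEnv.LogWarn(fmt.Sprintf("GetProperty text error: %v", err))
		return
	}
	tenEnv.LogDebug(fmt.Sprintf("received text: %s", text))
}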
agents/ten_packages/extension/message_collector/__init__.py (3 changes: 0 additions & 3 deletions)

@@ -6,6 +6,3 @@
 #
 #
 from .src import addon
-from .src.log import logger
-
-logger.info("message_collector extension loaded")
agents/ten_packages/extension/message_collector/src/addon.py (6 changes: 3 additions & 3 deletions)

@@ -17,6 +17,6 @@ class MessageCollectorExtensionAddon(Addon):
 
     def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
         from .extension import MessageCollectorExtension
-        from .log import logger
-        logger.info("MessageCollectorExtensionAddon on_create_instance")
-        ten_env.on_create_instance_done(MessageCollectorExtension(name), context)
+        ten_env.log_info("on_create_instance")
+        ten_env.on_create_instance_done(
+            MessageCollectorExtension(name), context)
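The Python side mirrors the Go change: the per-module logger import goes away, and the TenEnv instance handed to each callback carries the log methods. A minimal sketch, showing only methods that appear in this PR (log_debug, log_info, log_warn, log_error); the class is hypothetical and the import path is assumed from this package's other files:

from ten import Extension, TenEnv


class ExampleExtension(Extension):  # hypothetical, for illustration
    def on_start(self, ten_env: TenEnv) -> None:
        # one call per severity; these replace the old
        # logger.debug/info/warning/exception calls
        ten_env.log_debug("on_start entered")
        ten_env.log_info("on_start")
        ten_env.on_start_done()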
agents/ten_packages/extension/message_collector/src/extension.py (56 changes: 31 additions & 25 deletions)

@@ -21,7 +21,6 @@
     Data,
 )
 import asyncio
-from .log import logger
 
 MAX_SIZE = 800  # 1 KB limit
 OVERHEAD_ESTIMATE = 200  # Estimate for the overhead of metadata in the JSON
@@ -37,38 +36,40 @@
 cached_text_map = {}
 MAX_CHUNK_SIZE_BYTES = 1024
 
+
 def _text_to_base64_chunks(text: str, msg_id: str) -> list:
     # Ensure msg_id does not exceed 50 characters
     if len(msg_id) > 36:
         raise ValueError("msg_id cannot exceed 36 characters.")
 
     # Convert text to bytearray
     byte_array = bytearray(text, 'utf-8')
 
     # Encode the bytearray into base64
     base64_encoded = base64.b64encode(byte_array).decode('utf-8')
 
     # Initialize list to hold the final chunks
     chunks = []
 
     # We'll split the base64 string dynamically based on the final byte size
     part_index = 0
     total_parts = None  # We'll calculate total parts once we know how many chunks we create
 
     # Process the base64-encoded content in chunks
     current_position = 0
     total_length = len(base64_encoded)
 
     while current_position < total_length:
         part_index += 1
 
         # Start guessing the chunk size by limiting the base64 content part
         estimated_chunk_size = MAX_CHUNK_SIZE_BYTES  # We'll reduce this dynamically
         content_chunk = ""
         count = 0
         while True:
             # Create the content part of the chunk
-            content_chunk = base64_encoded[current_position:current_position + estimated_chunk_size]
+            content_chunk = base64_encoded[current_position:
+                                           current_position + estimated_chunk_size]
 
             # Format the chunk
             formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}"
@@ -81,11 +82,12 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list:
             estimated_chunk_size -= 100  # Reduce content size gradually
             count += 1
 
-        logger.debug(f"chunk estimate guess: {count}")
+        # logger.debug(f"chunk estimate guess: {count}")
 
         # Add the current chunk to the list
         chunks.append(formatted_chunk)
-        current_position += estimated_chunk_size  # Move to the next part of the content
+        # Move to the next part of the content
+        current_position += estimated_chunk_size
 
     # Now that we know the total number of parts, update the chunks with correct total_parts
     total_parts = len(chunks)
@@ -95,19 +97,21 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list:
 
     return updated_chunks
 
+
 class MessageCollectorExtension(Extension):
     # Create the queue for message processing
     queue = asyncio.Queue()
 
     def on_init(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_init")
+        ten_env.log_info("on_init")
         ten_env.on_init_done()
 
     def on_start(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_start")
+        ten_env.log_info("on_start")
 
         # TODO: read properties, initialize resources
         self.loop = asyncio.new_event_loop()
+
         def start_loop():
             asyncio.set_event_loop(self.loop)
             self.loop.run_forever()
@@ -118,19 +122,19 @@ def start_loop():
         ten_env.on_start_done()
 
     def on_stop(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_stop")
+        ten_env.log_info("on_stop")
 
         # TODO: clean up resources
 
         ten_env.on_stop_done()
 
     def on_deinit(self, ten_env: TenEnv) -> None:
-        logger.info("MessageCollectorExtension on_deinit")
+        ten_env.log_info("on_deinit")
         ten_env.on_deinit_done()
 
     def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
         cmd_name = cmd.get_name()
-        logger.info("on_cmd name {}".format(cmd_name))
+        ten_env.log_info("on_cmd name {}".format(cmd_name))
 
         # TODO: process cmd
 
@@ -145,7 +149,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
         example:
         {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}}
         """
-        logger.debug(f"on_data")
+        # ten_env.log_debug(f"on_data")
 
         text = ""
         final = True
@@ -155,7 +159,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
         try:
             text = data.get_property_string(TEXT_DATA_TEXT_FIELD)
         except Exception as e:
-            logger.exception(
+            ten_env.log_error(
                 f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}"
             )
 
@@ -170,13 +174,14 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
             pass
 
         try:
-            end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD)
+            end_of_segment = data.get_property_bool(
+                TEXT_DATA_END_OF_SEGMENT_FIELD)
         except Exception as e:
-            logger.warning(
+            ten_env.log_warn(
                 f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}"
             )
 
-        logger.debug(
+        ten_env.log_info(
             f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}"
         )
 
@@ -207,12 +212,14 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:
         }
 
         try:
-            chunks = _text_to_base64_chunks(json.dumps(base_msg_data), message_id)
+            chunks = _text_to_base64_chunks(
+                json.dumps(base_msg_data), message_id)
             for chunk in chunks:
-                asyncio.run_coroutine_threadsafe(self._queue_message(chunk), self.loop)
+                asyncio.run_coroutine_threadsafe(
+                    self._queue_message(chunk), self.loop)
 
         except Exception as e:
-            logger.warning(f"on_data new_data error: {e}")
+            ten_env.log_warn(f"on_data new_data error: {e}")
             return
 
     def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
@@ -223,7 +230,6 @@ def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
         # TODO: process image frame
         pass
 
-
     async def _queue_message(self, data: str):
         await self.queue.put(data)
 
@@ -237,4 +243,4 @@ async def _process_queue(self, ten_env: TenEnv):
             ten_data.set_property_buf("data", data.encode())
             ten_env.send_data(ten_data)
             self.queue.task_done()
-            await asyncio.sleep(0.04)
\ No newline at end of file
+            await asyncio.sleep(0.04)
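For reference, the wire format produced by _text_to_base64_chunks above is msg_id|part_index|total_parts|base64_content, with each formatted chunk kept under MAX_CHUNK_SIZE_BYTES. A small worked example of a single-part message (the msg_id value here is made up for illustration):

import base64

msg_id = "c1f6a2"  # any id up to 36 characters
payload = base64.b64encode("hello".encode("utf-8")).decode("utf-8")

# one chunk, part 1 of 1: msg_id|part_index|total_parts|base64_content
chunk = f"{msg_id}|1|1|{payload}"
print(chunk)  # c1f6a2|1|1|aGVsbG8=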
agents/ten_packages/extension/message_collector/src/log.py (22 changes: 0 additions & 22 deletions)

This file was deleted.