From cb95f49ab09410bf86a825a13793f4a6d79a4f9b Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Sun, 17 Nov 2024 00:27:15 +0800 Subject: [PATCH] feat: leverage ten's log (#417) --- .../extension/interrupt_detector/extension.go | 15 ++--- .../extension/message_collector/__init__.py | 3 - .../extension/message_collector/src/addon.py | 6 +- .../message_collector/src/extension.py | 56 ++++++++++--------- .../extension/message_collector/src/log.py | 22 -------- 5 files changed, 38 insertions(+), 64 deletions(-) delete mode 100644 agents/ten_packages/extension/message_collector/src/log.py diff --git a/agents/ten_packages/extension/interrupt_detector/extension.go b/agents/ten_packages/extension/interrupt_detector/extension.go index 4a3d4002..2048edbd 100644 --- a/agents/ten_packages/extension/interrupt_detector/extension.go +++ b/agents/ten_packages/extension/interrupt_detector/extension.go @@ -12,7 +12,6 @@ package extension import ( "fmt" - "log/slog" "ten_framework/ten" ) @@ -24,10 +23,6 @@ const ( cmdNameFlush = "flush" ) -var ( - logTag = slog.String("extension", "INTERRUPT_DETECTOR_EXTENSION") -) - type interruptDetectorExtension struct { ten.DefaultExtension } @@ -47,29 +42,27 @@ func (p *interruptDetectorExtension) OnData( ) { text, err := data.GetPropertyString(textDataTextField) if err != nil { - slog.Warn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataTextField, err), logTag) + tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataTextField, err)) return } final, err := data.GetPropertyBool(textDataFinalField) if err != nil { - slog.Warn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataFinalField, err), logTag) + tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s error: %v", textDataFinalField, err)) return } - slog.Debug(fmt.Sprintf("OnData %s: %s %s: %t", textDataTextField, text, textDataFinalField, final), logTag) + tenEnv.LogDebug(fmt.Sprintf("OnData %s: %s %s: %t", textDataTextField, text, textDataFinalField, final)) if final || len(text) >= 2 { flushCmd, _ := ten.NewCmd(cmdNameFlush) tenEnv.SendCmd(flushCmd, nil) - slog.Info(fmt.Sprintf("sent cmd: %s", cmdNameFlush), logTag) + tenEnv.LogInfo(fmt.Sprintf("sent cmd: %s", cmdNameFlush)) } } func init() { - slog.Info("interrupt_detector extension init", logTag) - // Register addon ten.RegisterAddonAsExtension( "interrupt_detector", diff --git a/agents/ten_packages/extension/message_collector/__init__.py b/agents/ten_packages/extension/message_collector/__init__.py index 46f01a81..645dc801 100644 --- a/agents/ten_packages/extension/message_collector/__init__.py +++ b/agents/ten_packages/extension/message_collector/__init__.py @@ -6,6 +6,3 @@ # # from .src import addon -from .src.log import logger - -logger.info("message_collector extension loaded") diff --git a/agents/ten_packages/extension/message_collector/src/addon.py b/agents/ten_packages/extension/message_collector/src/addon.py index 94fa5800..bab5eb19 100644 --- a/agents/ten_packages/extension/message_collector/src/addon.py +++ b/agents/ten_packages/extension/message_collector/src/addon.py @@ -17,6 +17,6 @@ class MessageCollectorExtensionAddon(Addon): def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: from .extension import MessageCollectorExtension - from .log import logger - logger.info("MessageCollectorExtensionAddon on_create_instance") - ten_env.on_create_instance_done(MessageCollectorExtension(name), context) + ten_env.log_info("on_create_instance") + ten_env.on_create_instance_done( + MessageCollectorExtension(name), context) diff --git a/agents/ten_packages/extension/message_collector/src/extension.py b/agents/ten_packages/extension/message_collector/src/extension.py index f54592b0..a6f8f9ac 100644 --- a/agents/ten_packages/extension/message_collector/src/extension.py +++ b/agents/ten_packages/extension/message_collector/src/extension.py @@ -21,7 +21,6 @@ Data, ) import asyncio -from .log import logger MAX_SIZE = 800 # 1 KB limit OVERHEAD_ESTIMATE = 200 # Estimate for the overhead of metadata in the JSON @@ -37,20 +36,21 @@ cached_text_map = {} MAX_CHUNK_SIZE_BYTES = 1024 + def _text_to_base64_chunks(text: str, msg_id: str) -> list: # Ensure msg_id does not exceed 50 characters if len(msg_id) > 36: raise ValueError("msg_id cannot exceed 36 characters.") - + # Convert text to bytearray byte_array = bytearray(text, 'utf-8') - + # Encode the bytearray into base64 base64_encoded = base64.b64encode(byte_array).decode('utf-8') - + # Initialize list to hold the final chunks chunks = [] - + # We'll split the base64 string dynamically based on the final byte size part_index = 0 total_parts = None # We'll calculate total parts once we know how many chunks we create @@ -58,17 +58,18 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list: # Process the base64-encoded content in chunks current_position = 0 total_length = len(base64_encoded) - + while current_position < total_length: part_index += 1 - + # Start guessing the chunk size by limiting the base64 content part estimated_chunk_size = MAX_CHUNK_SIZE_BYTES # We'll reduce this dynamically content_chunk = "" count = 0 while True: # Create the content part of the chunk - content_chunk = base64_encoded[current_position:current_position + estimated_chunk_size] + content_chunk = base64_encoded[current_position: + current_position + estimated_chunk_size] # Format the chunk formatted_chunk = f"{msg_id}|{part_index}|{total_parts if total_parts else '???'}|{content_chunk}" @@ -81,11 +82,12 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list: estimated_chunk_size -= 100 # Reduce content size gradually count += 1 - logger.debug(f"chunk estimate guess: {count}") + # logger.debug(f"chunk estimate guess: {count}") # Add the current chunk to the list chunks.append(formatted_chunk) - current_position += estimated_chunk_size # Move to the next part of the content + # Move to the next part of the content + current_position += estimated_chunk_size # Now that we know the total number of parts, update the chunks with correct total_parts total_parts = len(chunks) @@ -95,19 +97,21 @@ def _text_to_base64_chunks(text: str, msg_id: str) -> list: return updated_chunks + class MessageCollectorExtension(Extension): # Create the queue for message processing queue = asyncio.Queue() def on_init(self, ten_env: TenEnv) -> None: - logger.info("MessageCollectorExtension on_init") + ten_env.log_info("on_init") ten_env.on_init_done() def on_start(self, ten_env: TenEnv) -> None: - logger.info("MessageCollectorExtension on_start") + ten_env.log_info("on_start") # TODO: read properties, initialize resources self.loop = asyncio.new_event_loop() + def start_loop(): asyncio.set_event_loop(self.loop) self.loop.run_forever() @@ -118,19 +122,19 @@ def start_loop(): ten_env.on_start_done() def on_stop(self, ten_env: TenEnv) -> None: - logger.info("MessageCollectorExtension on_stop") + ten_env.log_info("on_stop") # TODO: clean up resources ten_env.on_stop_done() def on_deinit(self, ten_env: TenEnv) -> None: - logger.info("MessageCollectorExtension on_deinit") + ten_env.log_info("on_deinit") ten_env.on_deinit_done() def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None: cmd_name = cmd.get_name() - logger.info("on_cmd name {}".format(cmd_name)) + ten_env.log_info("on_cmd name {}".format(cmd_name)) # TODO: process cmd @@ -145,7 +149,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: example: {"name": "text_data", "properties": {"text": "hello", "is_final": true, "stream_id": 123, "end_of_segment": true}} """ - logger.debug(f"on_data") + # ten_env.log_debug(f"on_data") text = "" final = True @@ -155,7 +159,7 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: try: text = data.get_property_string(TEXT_DATA_TEXT_FIELD) except Exception as e: - logger.exception( + ten_env.log_error( f"on_data get_property_string {TEXT_DATA_TEXT_FIELD} error: {e}" ) @@ -170,13 +174,14 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: pass try: - end_of_segment = data.get_property_bool(TEXT_DATA_END_OF_SEGMENT_FIELD) + end_of_segment = data.get_property_bool( + TEXT_DATA_END_OF_SEGMENT_FIELD) except Exception as e: - logger.warning( + ten_env.log_warn( f"on_data get_property_bool {TEXT_DATA_END_OF_SEGMENT_FIELD} error: {e}" ) - logger.debug( + ten_env.log_info( f"on_data {TEXT_DATA_TEXT_FIELD}: {text} {TEXT_DATA_FINAL_FIELD}: {final} {TEXT_DATA_STREAM_ID_FIELD}: {stream_id} {TEXT_DATA_END_OF_SEGMENT_FIELD}: {end_of_segment}" ) @@ -207,12 +212,14 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None: } try: - chunks = _text_to_base64_chunks(json.dumps(base_msg_data), message_id) + chunks = _text_to_base64_chunks( + json.dumps(base_msg_data), message_id) for chunk in chunks: - asyncio.run_coroutine_threadsafe(self._queue_message(chunk), self.loop) + asyncio.run_coroutine_threadsafe( + self._queue_message(chunk), self.loop) except Exception as e: - logger.warning(f"on_data new_data error: {e}") + ten_env.log_warn(f"on_data new_data error: {e}") return def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None: @@ -223,7 +230,6 @@ def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None: # TODO: process image frame pass - async def _queue_message(self, data: str): await self.queue.put(data) @@ -237,4 +243,4 @@ async def _process_queue(self, ten_env: TenEnv): ten_data.set_property_buf("data", data.encode()) ten_env.send_data(ten_data) self.queue.task_done() - await asyncio.sleep(0.04) \ No newline at end of file + await asyncio.sleep(0.04) diff --git a/agents/ten_packages/extension/message_collector/src/log.py b/agents/ten_packages/extension/message_collector/src/log.py deleted file mode 100644 index ff7a400f..00000000 --- a/agents/ten_packages/extension/message_collector/src/log.py +++ /dev/null @@ -1,22 +0,0 @@ -# -# -# Agora Real Time Engagement -# Created by Wei Hu in 2024-08. -# Copyright (c) 2024 Agora IO. All rights reserved. -# -# -import logging - -logger = logging.getLogger("message_collector") -logger.setLevel(logging.INFO) - -formatter_str = ( - "%(asctime)s - %(name)s - %(levelname)s - %(process)d - " - "[%(filename)s:%(lineno)d] - %(message)s" -) -formatter = logging.Formatter(formatter_str) - -console_handler = logging.StreamHandler() -console_handler.setFormatter(formatter) - -logger.addHandler(console_handler)