diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py
index 46c8d4e0a4d3..e96dc39fa7bc 100644
--- a/autogen/agentchat/contrib/capabilities/transform_messages.py
+++ b/autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -1,10 +1,9 @@
 import copy
 from typing import Dict, List
 
-from termcolor import colored
-
 from autogen import ConversableAgent
 
+from ....formatting_utils import colored
 from .transforms import MessageTransform
 
 
@@ -43,12 +42,14 @@ class TransformMessages:
     ```
     """
 
-    def __init__(self, *, transforms: List[MessageTransform] = []):
+    def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
         """
         Args:
             transforms: A list of message transformations to apply.
+            verbose: Whether to print logs of each transformation or not.
         """
         self._transforms = transforms
+        self._verbose = verbose
 
     def add_to_agent(self, agent: ConversableAgent):
         """Adds the message transformations capability to the specified ConversableAgent.
@@ -61,31 +62,26 @@ def add_to_agent(self, agent: ConversableAgent):
         agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
 
     def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
-        temp_messages = copy.deepcopy(messages)
+        post_transform_messages = copy.deepcopy(messages)
         system_message = None
 
         if messages[0]["role"] == "system":
             system_message = copy.deepcopy(messages[0])
-            temp_messages.pop(0)
+            post_transform_messages.pop(0)
 
         for transform in self._transforms:
-            temp_messages = transform.apply_transform(temp_messages)
-
-        if system_message:
-            temp_messages.insert(0, system_message)
-
-        self._print_stats(messages, temp_messages)
+            # deepcopy in case pre_transform_messages is later used for logging
+            pre_transform_messages = (
+                copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
+            )
+            post_transform_messages = transform.apply_transform(pre_transform_messages)
 
-        return temp_messages
+            if self._verbose:
+                logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
+                if had_effect:
+                    print(colored(logs_str, "yellow"))
 
-    def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]):
-        pre_transform_messages_len = len(pre_transform_messages)
-        post_transform_messages_len = len(post_transform_messages)
+        if system_message:
+            post_transform_messages.insert(0, system_message)
 
-        if pre_transform_messages_len < post_transform_messages_len:
-            print(
-                colored(
-                    f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
-                    "yellow",
-                )
-            )
+        return post_transform_messages
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index cc4faace3f18..6dc1d59fe9c7 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -1,5 +1,6 @@
+import copy
 import sys
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
 
 import tiktoken
 from termcolor import colored
@@ -25,6 +26,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """
         ...
 
+    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+        """Creates a string with the logs of the transformation.
+
+        Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
+
+        Args:
+            pre_transform_messages: A list of dictionaries representing messages before the transformation.
+            post_transform_messages: A list of dictionaries representing messages after the transformation.
+
+        Returns:
+            A tuple with the log string and a flag indicating whether the transformation had an effect or not.
+        """
+        ...
+
 
 class MessageHistoryLimiter:
     """Limits the number of messages considered by an agent for response generation.
@@ -60,6 +75,18 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
 
         return messages[-self._max_messages :]
 
+    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+        pre_transform_messages_len = len(pre_transform_messages)
+        post_transform_messages_len = len(post_transform_messages)
+
+        if post_transform_messages_len < pre_transform_messages_len:
+            logs_str = (
+                f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
+                f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
+            )
+            return logs_str, True
+        return "No messages were removed.", False
+
     def _validate_max_messages(self, max_messages: Optional[int]):
         if max_messages is not None and max_messages < 1:
             raise ValueError("max_messages must be None or greater than 1")
@@ -121,15 +148,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         assert self._max_tokens_per_message is not None
         assert self._max_tokens is not None
 
-        temp_messages = messages.copy()
+        temp_messages = copy.deepcopy(messages)
         processed_messages = []
         processed_messages_tokens = 0
 
-        # calculate tokens for all messages
-        total_tokens = sum(
-            _count_tokens(msg["content"]) for msg in temp_messages if isinstance(msg.get("content"), (str, list))
-        )
-
         for msg in reversed(temp_messages):
             # Some messages may not have content.
             if not isinstance(msg.get("content"), (str, list)):
@@ -154,16 +176,24 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
             processed_messages_tokens += msg_tokens
             processed_messages.insert(0, msg)
 
-        if total_tokens > processed_messages_tokens:
-            print(
-                colored(
-                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
-                    "yellow",
-                )
-            )
-
         return processed_messages
 
+    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+        pre_transform_messages_tokens = sum(
+            _count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
+        )
+        post_transform_messages_tokens = sum(
+            _count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
+        )
+
+        if post_transform_messages_tokens < pre_transform_messages_tokens:
+            logs_str = (
+                f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
+                f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
+            )
+            return logs_str, True
+        return "No tokens were truncated.", False
+
     def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
         if isinstance(contents, str):
             return self._truncate_tokens(contents, n_tokens)
diff --git a/test/agentchat/contrib/capabilities/test_transform_messages.py b/test/agentchat/contrib/capabilities/test_transform_messages.py
index ac0cdf58755b..4888b49f3274 100644
--- a/test/agentchat/contrib/capabilities/test_transform_messages.py
+++ b/test/agentchat/contrib/capabilities/test_transform_messages.py
@@ -1,4 +1,3 @@
-import copy
 import os
 import sys
 import tempfile
@@ -7,7 +6,6 @@
 import pytest
 
 import autogen
-from autogen import token_count_utils
 from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
 from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter
 
@@ -18,106 +16,6 @@
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST  # noqa: E402
 
 
-def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
-    token_count = 0
-    if isinstance(content, str):
-        token_count = token_count_utils.count_token(content)
-    elif isinstance(content, list):
-        for item in content:
-            token_count += _count_tokens(item.get("text", ""))
-    return token_count
-
-
-def test_limit_token_transform():
-    """
-    Test the TokenLimitTransform capability.
-    """
-
-    messages = [
-        {"role": "user", "content": "short string"},
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": "very very very very very very very very long string"}],
-        },
-    ]
-
-    # check if token limit per message is not exceeded.
-    max_tokens_per_message = 5
-    token_limit_transform = MessageTokenLimiter(max_tokens_per_message=max_tokens_per_message)
-    transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
-
-    for message in transformed_messages:
-        assert _count_tokens(message["content"]) <= max_tokens_per_message
-
-    # check if total token limit is not exceeded.
-    max_tokens = 10
-    token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens)
-    transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
-
-    token_count = 0
-    for message in transformed_messages:
-        token_count += _count_tokens(message["content"])
-
-    assert token_count <= max_tokens
-    assert len(transformed_messages) <= len(messages)
-
-    # check if token limit per message works nicely with total token limit.
-    token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens, max_tokens_per_message=max_tokens_per_message)
-
-    transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
-
-    token_count = 0
-    for message in transformed_messages:
-        token_count_local = _count_tokens(message["content"])
-        token_count += token_count_local
-        assert token_count_local <= max_tokens_per_message
-
-    assert token_count <= max_tokens
-    assert len(transformed_messages) <= len(messages)
-
-
-def test_limit_token_transform_without_content():
-    """Test the TokenLimitTransform with messages that don't have content."""
-
-    messages = [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
-
-    # check if token limit per message works nicely with total token limit.
-    token_limit_transform = MessageTokenLimiter(max_tokens=10, max_tokens_per_message=5)
-
-    transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
-
-    assert len(transformed_messages) == len(messages)
-
-
-def test_limit_token_transform_total_token_count():
-    """Tests if the TokenLimitTransform truncates without dropping messages."""
-    messages = [{"role": "very very very very very"}]
-
-    token_limit_transform = MessageTokenLimiter(max_tokens=1)
-    transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
-
-    assert len(transformed_messages) == 1
-
-
-def test_max_message_history_length_transform():
-    """
-    Test the MessageHistoryLimiter capability to limit the number of messages.
-    """
-    messages = [
-        {"role": "user", "content": "hello"},
-        {"role": "assistant", "content": [{"type": "text", "text": "there"}]},
-        {"role": "user", "content": "how"},
-        {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
-    ]
-
-    max_messages = 2
-    messages_limiter = MessageHistoryLimiter(max_messages=max_messages)
-    transformed_messages = messages_limiter.apply_transform(copy.deepcopy(messages))
-
-    assert len(transformed_messages) == max_messages
-    assert transformed_messages == messages[max_messages:]
-
-
 @pytest.mark.skipif(skip_openai, reason="Requested to skip openai test.")
 def test_transform_messages_capability():
     """Test the TransformMessages capability to handle long contexts.
@@ -172,6 +70,4 @@
 
 
 if __name__ == "__main__":
-    test_limit_token_transform()
-    test_max_message_history_length_transform()
     test_transform_messages_capability()
diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py
new file mode 100644
index 000000000000..1a929e4c6ba1
--- /dev/null
+++ b/test/agentchat/contrib/capabilities/test_transforms.py
@@ -0,0 +1,122 @@
+import copy
+from typing import Dict, List
+
+import pytest
+
+from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens
+
+
+def get_long_messages() -> List[Dict]:
+    return [
+        {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
+        {"role": "user", "content": "very very very very very very long string"},
+        {"role": "user", "content": "hello"},
+        {"role": "assistant", "content": [{"type": "text", "text": "there"}]},
+        {"role": "user", "content": "how"},
+    ]
+
+
+def get_short_messages() -> List[Dict]:
+    return [
+        {"role": "user", "content": "hello"},
+        {"role": "assistant", "content": [{"type": "text", "text": "there"}]},
+        {"role": "user", "content": "how"},
+    ]
+
+
+def get_no_content_messages() -> List[Dict]:
+    return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
+
+
+@pytest.fixture
+def message_history_limiter() -> MessageHistoryLimiter:
+    return MessageHistoryLimiter(max_messages=3)
+
+
+@pytest.fixture
+def message_token_limiter() -> MessageTokenLimiter:
+    return MessageTokenLimiter(max_tokens_per_message=3)
+
+
+# MessageHistoryLimiter tests
+
+
+@pytest.mark.parametrize(
+    "messages, expected_messages_len",
+    [(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2)],
+)
+def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len):
+    transformed_messages = message_history_limiter.apply_transform(messages)
+    assert len(transformed_messages) == expected_messages_len
+
+
+@pytest.mark.parametrize(
+    "messages, expected_logs, expected_effect",
+    [
+        (get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True),
+        (get_short_messages(), "No messages were removed.", False),
+        (get_no_content_messages(), "No messages were removed.", False),
+    ],
+)
+def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect):
+    pre_transform_messages = copy.deepcopy(messages)
+    transformed_messages = message_history_limiter.apply_transform(messages)
+    logs_str, had_effect = message_history_limiter.get_logs(pre_transform_messages, transformed_messages)
+    assert had_effect == expected_effect
+    assert logs_str == expected_logs
+
+
+# MessageTokenLimiter tests
+
+
+@pytest.mark.parametrize(
+    "messages, expected_token_count, expected_messages_len",
+    [(get_long_messages(), 9, 5), (get_short_messages(), 3, 3), (get_no_content_messages(), 0, 2)],
+)
+def test_message_token_limiter_apply_transform(
+    message_token_limiter, messages, expected_token_count, expected_messages_len
+):
+    transformed_messages = message_token_limiter.apply_transform(messages)
+    assert (
+        sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
+    )
+    assert len(transformed_messages) == expected_messages_len
+
+
+@pytest.mark.parametrize(
+    "messages, expected_logs, expected_effect",
+    [
+        (get_long_messages(), "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True),
+        (get_short_messages(), "No tokens were truncated.", False),
+        (get_no_content_messages(), "No tokens were truncated.", False),
+    ],
+)
+def test_message_token_limiter_get_logs(message_token_limiter, messages, expected_logs, expected_effect):
+    pre_transform_messages = copy.deepcopy(messages)
+    transformed_messages = message_token_limiter.apply_transform(messages)
+    logs_str, had_effect = message_token_limiter.get_logs(pre_transform_messages, transformed_messages)
+    assert had_effect == expected_effect
+    assert logs_str == expected_logs
+
+
+if __name__ == "__main__":
+    long_messages = get_long_messages()
+    short_messages = get_short_messages()
+    message_history_limiter = MessageHistoryLimiter(max_messages=3)
+    message_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
+
+    # Call the MessageHistoryLimiter tests
+    test_message_history_limiter_apply_transform(message_history_limiter, long_messages, 3)
+    test_message_history_limiter_apply_transform(message_history_limiter, short_messages, 3)
+    test_message_history_limiter_get_logs(
+        message_history_limiter, long_messages, "Removed 2 messages. Number of messages reduced from 5 to 3.", True
+    )
+    test_message_history_limiter_get_logs(message_history_limiter, short_messages, "No messages were removed.", False)
+
+    # Call the MessageTokenLimiter tests
+    test_message_token_limiter_apply_transform(message_token_limiter, long_messages, 9, 5)
+    test_message_token_limiter_apply_transform(message_token_limiter, short_messages, 3, 3)
+    test_message_token_limiter_get_logs(
+        message_token_limiter, long_messages, "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True
+    )
+    test_message_token_limiter_get_logs(message_token_limiter, short_messages, "No tokens were truncated.", False)