From 07f82a7e8efca2bbde532cc58296b44e80711719 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Sat, 6 Apr 2024 18:33:20 +0300 Subject: [PATCH 01/13] Standardize printing of MessageTransforms --- .../capabilities/transform_messages.py | 25 +++++------- .../contrib/capabilities/transforms.py | 39 +++++++++++++++---- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 46c8d4e0a4d3..784619abb66c 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -43,12 +43,14 @@ class TransformMessages: ``` """ - def __init__(self, *, transforms: List[MessageTransform] = []): + def __init__(self, *, transforms: List[MessageTransform] = [], to_print_stats: bool = True): """ Args: transforms: A list of message transformations to apply. + to_print_stats: Whether to print stats of each transformation or not. """ self._transforms = transforms + self._to_print_stats = to_print_stats def add_to_agent(self, agent: ConversableAgent): """Adds the message transformations capability to the specified ConversableAgent. @@ -69,23 +71,14 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: temp_messages.pop(0) for transform in self._transforms: - temp_messages = transform.apply_transform(temp_messages) + if self._to_print_stats: + pre_transform_messages = copy.deepcopy(temp_messages) + temp_messages = transform.apply_transform(temp_messages) + transform.print_stats(pre_transform_messages, temp_messages) + else: + temp_messages = transform.apply_transform(temp_messages) if system_message: temp_messages.insert(0, system_message) - self._print_stats(messages, temp_messages) - return temp_messages - - def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): - pre_transform_messages_len = len(pre_transform_messages) - post_transform_messages_len = len(post_transform_messages) - - if pre_transform_messages_len < post_transform_messages_len: - print( - colored( - f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.", - "yellow", - ) - ) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index d87963ec82eb..34056371c5a5 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -25,6 +25,15 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: """ ... + def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + """Prints stats of the executed transformation. + + Args: + pre_transform_messages: A list of dictionaries representing messages before the transformation. + post_transform_messages: A list of dictionaries representig messages after the transformation. + """ + ... + class MessageHistoryLimiter: """Limits the number of messages considered by an agent for response generation. @@ -60,6 +69,19 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return messages[-self._max_messages :] + def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + pre_transform_messages_len = len(pre_transform_messages) + post_transform_messages_len = len(post_transform_messages) + + if post_transform_messages_len < pre_transform_messages_len: + print( + colored( + f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " \ + f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.", + "yellow", + ) + ) + def _validate_max_messages(self, max_messages: Optional[int]): if max_messages is not None and max_messages < 1: raise ValueError("max_messages must be None or greater than 1") @@ -124,9 +146,6 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: processed_messages = [] processed_messages_tokens = 0 - # calculate tokens for all messages - total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages) - for msg in reversed(temp_messages): msg["content"] = self._truncate_str_to_tokens(msg["content"]) msg_tokens = _count_tokens(msg["content"]) @@ -139,16 +158,22 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: processed_messages_tokens += msg_tokens processed_messages.insert(0, msg) - if total_tokens > processed_messages_tokens: + + return processed_messages + + def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + pre_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in pre_transform_messages) + post_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in post_transform_messages) + + if post_transform_messages_tokens < pre_transform_messages_tokens: print( colored( - f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}", + f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " \ + f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}", "yellow", ) ) - return processed_messages - def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]: if isinstance(contents, str): return self._truncate_tokens(contents) From a0295240eddba2993571e23d88ce095a882b4074 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Sat, 6 Apr 2024 19:03:58 +0300 Subject: [PATCH 02/13] Fix pre-commit black failure --- autogen/agentchat/contrib/capabilities/transform_messages.py | 2 +- autogen/agentchat/contrib/capabilities/transforms.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 784619abb66c..306b792ef3a0 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -76,7 +76,7 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: temp_messages = transform.apply_transform(temp_messages) transform.print_stats(pre_transform_messages, temp_messages) else: - temp_messages = transform.apply_transform(temp_messages) + temp_messages = transform.apply_transform(temp_messages) if system_message: temp_messages.insert(0, system_message) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 34056371c5a5..2dceeaff76ee 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -76,7 +76,7 @@ def print_stats(self, pre_transform_messages: List[Dict], post_transform_message if post_transform_messages_len < pre_transform_messages_len: print( colored( - f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " \ + f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.", "yellow", ) @@ -158,7 +158,6 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: processed_messages_tokens += msg_tokens processed_messages.insert(0, msg) - return processed_messages def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): @@ -168,7 +167,7 @@ def print_stats(self, pre_transform_messages: List[Dict], post_transform_message if post_transform_messages_tokens < pre_transform_messages_tokens: print( colored( - f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " \ + f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}", "yellow", ) From 961475ad0ee766b59d9da1a58458fdb4cd177322 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Sat, 6 Apr 2024 19:34:33 +0300 Subject: [PATCH 03/13] Add test for transform_messages printing --- .../capabilities/test_transform_messages.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/agentchat/contrib/capabilities/test_transform_messages.py b/test/agentchat/contrib/capabilities/test_transform_messages.py index c5b7c1dcf2d2..82bfdeb2bda4 100644 --- a/test/agentchat/contrib/capabilities/test_transform_messages.py +++ b/test/agentchat/contrib/capabilities/test_transform_messages.py @@ -148,6 +148,34 @@ def test_transform_messages_capability(): assert False, f"Chat initiation failed with error {str(e)}" +def test_transform_messages_printing(capsys): + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, + {"role": "user", "content": "how"}, + {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, + {"role": "user", "content": "very very very very very very long string"}, + ] + + context_handling = TransformMessages( + transforms=[ + MessageHistoryLimiter(max_messages=3), + MessageTokenLimiter(max_tokens=1000, max_tokens_per_message=3), + ], + ) + + context_handling._transform_messages(messages) + + captured = capsys.readouterr() + + captured_output = captured.out.strip().split("\n") + + assert captured_output == [ + "Removed 2 messages. Number of messages reduced from 5 to 3.", + "Truncated 6 tokens. Number of tokens reduced from 13 to 7", + ] + + if __name__ == "__main__": test_limit_token_transform() test_max_message_history_length_transform() From 0136a82c30fef392e981507d2b40768f711e89d8 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Mon, 8 Apr 2024 11:45:55 +0300 Subject: [PATCH 04/13] Return str instead of printing --- .../capabilities/transform_messages.py | 4 +- .../contrib/capabilities/transforms.py | 42 +++--- .../capabilities/test_transform_messages.py | 28 ---- .../contrib/capabilities/test_transforms.py | 121 ++++++++++++++++++ 4 files changed, 148 insertions(+), 47 deletions(-) create mode 100644 test/agentchat/contrib/capabilities/test_transforms.py diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 306b792ef3a0..6c80234e2716 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -74,7 +74,9 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: if self._to_print_stats: pre_transform_messages = copy.deepcopy(temp_messages) temp_messages = transform.apply_transform(temp_messages) - transform.print_stats(pre_transform_messages, temp_messages) + stats_str, had_effect = transform.get_stats_str(pre_transform_messages, temp_messages) + if had_effect: + print(stats_str) else: temp_messages = transform.apply_transform(temp_messages) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 2dceeaff76ee..07c2ec12654a 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -1,5 +1,6 @@ +import copy import sys -from typing import Any, Dict, List, Optional, Protocol, Union +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union import tiktoken from termcolor import colored @@ -25,12 +26,19 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: """ ... - def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): - """Prints stats of the executed transformation. + def get_stats_str( + self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict] + ) -> Tuple[str, bool]: + """Creates the string includin the stats of the transformation + + Alongside the string, it returns a boolean indicating whether the transformation had an effect or not. Args: pre_transform_messages: A list of dictionaries representing messages before the transformation. post_transform_messages: A list of dictionaries representig messages after the transformation. + + Returns: + A tuple with a string with the stats and a flag indicating whether the transformation had an effect or not. """ ... @@ -69,18 +77,17 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return messages[-self._max_messages :] - def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + def get_stats_str(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): pre_transform_messages_len = len(pre_transform_messages) post_transform_messages_len = len(post_transform_messages) if post_transform_messages_len < pre_transform_messages_len: - print( - colored( - f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " - f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.", - "yellow", - ) + stats_str = ( + f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " + f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}." ) + return stats_str, True + return "No messages were removed.", False def _validate_max_messages(self, max_messages: Optional[int]): if max_messages is not None and max_messages < 1: @@ -142,7 +149,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: assert self._max_tokens_per_message is not None assert self._max_tokens is not None - temp_messages = messages.copy() + temp_messages = copy.deepcopy(messages) processed_messages = [] processed_messages_tokens = 0 @@ -160,18 +167,17 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return processed_messages - def print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + def get_stats_str(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): pre_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in pre_transform_messages) post_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in post_transform_messages) if post_transform_messages_tokens < pre_transform_messages_tokens: - print( - colored( - f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " - f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}", - "yellow", - ) + stats_str = ( + f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " + f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" ) + return stats_str, True + return "No tokens were truncated.", False def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]: if isinstance(contents, str): diff --git a/test/agentchat/contrib/capabilities/test_transform_messages.py b/test/agentchat/contrib/capabilities/test_transform_messages.py index 82bfdeb2bda4..c5b7c1dcf2d2 100644 --- a/test/agentchat/contrib/capabilities/test_transform_messages.py +++ b/test/agentchat/contrib/capabilities/test_transform_messages.py @@ -148,34 +148,6 @@ def test_transform_messages_capability(): assert False, f"Chat initiation failed with error {str(e)}" -def test_transform_messages_printing(capsys): - messages = [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, - {"role": "user", "content": "how"}, - {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, - {"role": "user", "content": "very very very very very very long string"}, - ] - - context_handling = TransformMessages( - transforms=[ - MessageHistoryLimiter(max_messages=3), - MessageTokenLimiter(max_tokens=1000, max_tokens_per_message=3), - ], - ) - - context_handling._transform_messages(messages) - - captured = capsys.readouterr() - - captured_output = captured.out.strip().split("\n") - - assert captured_output == [ - "Removed 2 messages. Number of messages reduced from 5 to 3.", - "Truncated 6 tokens. Number of tokens reduced from 13 to 7", - ] - - if __name__ == "__main__": test_limit_token_transform() test_max_message_history_length_transform() diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py new file mode 100644 index 000000000000..24edbd5f241d --- /dev/null +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -0,0 +1,121 @@ +import copy +from typing import Dict, List + +import pytest + +from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens + + +@pytest.fixture +def long_messages() -> List[Dict]: + return [ + {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, + {"role": "user", "content": "very very very very very very long string"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, + {"role": "user", "content": "how"}, + ] + + +@pytest.fixture +def short_messages() -> List[Dict]: + return [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, + {"role": "user", "content": "how"}, + ] + + +@pytest.fixture +def message_history_limiter() -> MessageHistoryLimiter: + return MessageHistoryLimiter(max_messages=3) + + +@pytest.fixture +def message_token_limiter() -> MessageTokenLimiter: + return MessageTokenLimiter(max_tokens_per_message=3) + + +# MessageHistoryLimiter tests + + +def test_MessageHistoryLimiter_apply_transform_long(message_history_limiter, long_messages): + transformed_messages = message_history_limiter.apply_transform(long_messages) + assert len(transformed_messages) == 3 + + +def test_MessageHistoryLimiter_apply_transform_short(message_history_limiter, short_messages): + transformed_messages = message_history_limiter.apply_transform(short_messages) + assert len(transformed_messages) == 3 + + +def test_MessageHistoryLimiter_get_stats_str_long(message_history_limiter, long_messages): + pre_transform_messages = copy.deepcopy(long_messages) + transformed_messages = message_history_limiter.apply_transform(long_messages) + stats_str, had_effect = message_history_limiter.get_stats_str(pre_transform_messages, transformed_messages) + assert had_effect + assert stats_str == "Removed 2 messages. Number of messages reduced from 5 to 3." + + +def test_MessageHistoryLimiter_get_stats_str_short(message_history_limiter, short_messages): + pre_transform_messages = copy.deepcopy(short_messages) + transformed_messages = message_history_limiter.apply_transform(short_messages) + stats_str, had_effect = message_history_limiter.get_stats_str(pre_transform_messages, transformed_messages) + assert not had_effect + assert stats_str == "No messages were removed." + + +# MessageTokenLimiter tests + + +def test_MessageTokenLimiter_apply_transform_long(message_token_limiter, long_messages): + transformed_messages = message_token_limiter.apply_transform(long_messages) + assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == 9 + + +def test_MessageTokenLimiter_apply_transform_short(message_token_limiter, short_messages): + transformed_messages = message_token_limiter.apply_transform(short_messages) + assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == 3 + + +def test_MessageTokenLimiter_get_stats_str_long(message_token_limiter, long_messages): + pre_transform_messages = copy.deepcopy(long_messages) + transformed_messages = message_token_limiter.apply_transform(long_messages) + stats_str, had_effect = message_token_limiter.get_stats_str(pre_transform_messages, transformed_messages) + assert had_effect + assert stats_str == "Truncated 6 tokens. Number of tokens reduced from 15 to 9" + + +def test_MessageTokenLimiter_get_stats_str_short(message_token_limiter, short_messages): + pre_transform_messages = copy.deepcopy(short_messages) + transformed_messages = message_token_limiter.apply_transform(short_messages) + stats_str, had_effect = message_token_limiter.get_stats_str(pre_transform_messages, transformed_messages) + assert not had_effect + assert stats_str == "No tokens were truncated." + + +if __name__ == "__main__": + long_messages = [ + {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, + {"role": "user", "content": "very very very very very very long string"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, + {"role": "user", "content": "how"}, + ] + short_messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, + {"role": "user", "content": "how"}, + ] + message_history_limiter = MessageHistoryLimiter(max_messages=3) + message_token_limiter = MessageTokenLimiter(max_tokens_per_message=3) + + test_MessageHistoryLimiter_apply_transform_long(message_history_limiter, long_messages) + test_MessageHistoryLimiter_apply_transform_short(message_history_limiter, short_messages) + test_MessageHistoryLimiter_get_stats_str_long(message_history_limiter, long_messages) + test_MessageHistoryLimiter_get_stats_str_short(message_history_limiter, short_messages) + + test_MessageTokenLimiter_apply_transform_long(message_token_limiter, long_messages) + test_MessageTokenLimiter_apply_transform_short(message_token_limiter, short_messages) + test_MessageTokenLimiter_get_stats_str_long(message_token_limiter, long_messages) + test_MessageTokenLimiter_get_stats_str_short(message_token_limiter, short_messages) From e8f9ae37094b0b00ba44fed0998040ea11873762 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Mon, 8 Apr 2024 11:51:35 +0300 Subject: [PATCH 05/13] Rename to_print_stats to verbose --- .../agentchat/contrib/capabilities/transform_messages.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 6c80234e2716..7a3eeb5a4ddd 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -43,14 +43,14 @@ class TransformMessages: ``` """ - def __init__(self, *, transforms: List[MessageTransform] = [], to_print_stats: bool = True): + def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True): """ Args: transforms: A list of message transformations to apply. - to_print_stats: Whether to print stats of each transformation or not. + verbose: Whether to print stats of each transformation or not. """ self._transforms = transforms - self._to_print_stats = to_print_stats + self._verbose = verbose def add_to_agent(self, agent: ConversableAgent): """Adds the message transformations capability to the specified ConversableAgent. @@ -71,7 +71,7 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: temp_messages.pop(0) for transform in self._transforms: - if self._to_print_stats: + if self._verbose: pre_transform_messages = copy.deepcopy(temp_messages) temp_messages = transform.apply_transform(temp_messages) stats_str, had_effect = transform.get_stats_str(pre_transform_messages, temp_messages) From 1898862e2f9a2534c05db3b4828d59cf1ef0038f Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Tue, 9 Apr 2024 10:34:45 +0300 Subject: [PATCH 06/13] Cleanup --- .../capabilities/transform_messages.py | 21 ++-- .../contrib/capabilities/transforms.py | 10 +- .../contrib/capabilities/test_transforms.py | 107 +++++++++--------- 3 files changed, 68 insertions(+), 70 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 7a3eeb5a4ddd..9bfb9dc967e6 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -1,10 +1,9 @@ import copy from typing import Dict, List -from termcolor import colored - from autogen import ConversableAgent +from ....formatting_utils import colored from .transforms import MessageTransform @@ -63,24 +62,24 @@ def add_to_agent(self, agent: ConversableAgent): agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages) def _transform_messages(self, messages: List[Dict]) -> List[Dict]: - temp_messages = copy.deepcopy(messages) + post_transform_messages = copy.deepcopy(messages) system_message = None if messages[0]["role"] == "system": system_message = copy.deepcopy(messages[0]) - temp_messages.pop(0) + post_transform_messages.pop(0) for transform in self._transforms: if self._verbose: - pre_transform_messages = copy.deepcopy(temp_messages) - temp_messages = transform.apply_transform(temp_messages) - stats_str, had_effect = transform.get_stats_str(pre_transform_messages, temp_messages) + pre_transform_messages = copy.deepcopy(post_transform_messages) + post_transform_messages = transform.apply_transform(post_transform_messages) + stats_str, had_effect = transform.get_stats(pre_transform_messages, post_transform_messages) if had_effect: - print(stats_str) + print(colored(stats_str, "yellow")) else: - temp_messages = transform.apply_transform(temp_messages) + post_transform_messages = transform.apply_transform(post_transform_messages) if system_message: - temp_messages.insert(0, system_message) + post_transform_messages.insert(0, system_message) - return temp_messages + return post_transform_messages diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 07c2ec12654a..b47a1049ed93 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -26,10 +26,8 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: """ ... - def get_stats_str( - self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict] - ) -> Tuple[str, bool]: - """Creates the string includin the stats of the transformation + def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + """Creates the string including the stats of the transformation Alongside the string, it returns a boolean indicating whether the transformation had an effect or not. @@ -77,7 +75,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return messages[-self._max_messages :] - def get_stats_str(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: pre_transform_messages_len = len(pre_transform_messages) post_transform_messages_len = len(post_transform_messages) @@ -167,7 +165,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return processed_messages - def get_stats_str(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]): + def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: pre_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in pre_transform_messages) post_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in post_transform_messages) diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index 24edbd5f241d..407abaa33551 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -2,6 +2,7 @@ from typing import Dict, List import pytest +from pytest_lazyfixture import lazy_fixture from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens @@ -39,59 +40,53 @@ def message_token_limiter() -> MessageTokenLimiter: # MessageHistoryLimiter tests -def test_MessageHistoryLimiter_apply_transform_long(message_history_limiter, long_messages): - transformed_messages = message_history_limiter.apply_transform(long_messages) - assert len(transformed_messages) == 3 +@pytest.mark.parametrize( + "messages, expected_len", [(lazy_fixture("long_messages"), 3), (lazy_fixture("short_messages"), 3)] +) +def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_len): + transformed_messages = message_history_limiter.apply_transform(messages) + assert len(transformed_messages) == expected_len -def test_MessageHistoryLimiter_apply_transform_short(message_history_limiter, short_messages): - transformed_messages = message_history_limiter.apply_transform(short_messages) - assert len(transformed_messages) == 3 - - -def test_MessageHistoryLimiter_get_stats_str_long(message_history_limiter, long_messages): - pre_transform_messages = copy.deepcopy(long_messages) - transformed_messages = message_history_limiter.apply_transform(long_messages) - stats_str, had_effect = message_history_limiter.get_stats_str(pre_transform_messages, transformed_messages) - assert had_effect - assert stats_str == "Removed 2 messages. Number of messages reduced from 5 to 3." - - -def test_MessageHistoryLimiter_get_stats_str_short(message_history_limiter, short_messages): - pre_transform_messages = copy.deepcopy(short_messages) - transformed_messages = message_history_limiter.apply_transform(short_messages) - stats_str, had_effect = message_history_limiter.get_stats_str(pre_transform_messages, transformed_messages) - assert not had_effect - assert stats_str == "No messages were removed." +@pytest.mark.parametrize( + "messages, expected_stats, expected_effect", + [ + (lazy_fixture("long_messages"), "Removed 2 messages. Number of messages reduced from 5 to 3.", True), + (lazy_fixture("short_messages"), "No messages were removed.", False), + ], +) +def test_message_history_limiter_get_stats(message_history_limiter, messages, expected_stats, expected_effect): + pre_transform_messages = copy.deepcopy(messages) + transformed_messages = message_history_limiter.apply_transform(messages) + stats_str, had_effect = message_history_limiter.get_stats(pre_transform_messages, transformed_messages) + assert had_effect == expected_effect + assert stats_str == expected_stats # MessageTokenLimiter tests -def test_MessageTokenLimiter_apply_transform_long(message_token_limiter, long_messages): - transformed_messages = message_token_limiter.apply_transform(long_messages) - assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == 9 - +@pytest.mark.parametrize( + "messages, expected_token_count", [(lazy_fixture("long_messages"), 9), (lazy_fixture("short_messages"), 3)] +) +def test_message_token_limiter_apply_transform(message_token_limiter, messages, expected_token_count): + transformed_messages = message_token_limiter.apply_transform(messages) + assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == expected_token_count -def test_MessageTokenLimiter_apply_transform_short(message_token_limiter, short_messages): - transformed_messages = message_token_limiter.apply_transform(short_messages) - assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == 3 - -def test_MessageTokenLimiter_get_stats_str_long(message_token_limiter, long_messages): - pre_transform_messages = copy.deepcopy(long_messages) - transformed_messages = message_token_limiter.apply_transform(long_messages) - stats_str, had_effect = message_token_limiter.get_stats_str(pre_transform_messages, transformed_messages) - assert had_effect - assert stats_str == "Truncated 6 tokens. Number of tokens reduced from 15 to 9" - - -def test_MessageTokenLimiter_get_stats_str_short(message_token_limiter, short_messages): - pre_transform_messages = copy.deepcopy(short_messages) - transformed_messages = message_token_limiter.apply_transform(short_messages) - stats_str, had_effect = message_token_limiter.get_stats_str(pre_transform_messages, transformed_messages) - assert not had_effect - assert stats_str == "No tokens were truncated." +@pytest.mark.parametrize( + "messages, expected_stats, expected_effect", + [ + (lazy_fixture("long_messages"), "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True), + (lazy_fixture("short_messages"), "No tokens were truncated.", False), + ], +) +def test_message_token_limiter_get_stats(message_token_limiter, messages, expected_stats, expected_effect): + pre_transform_messages = copy.deepcopy(messages) + transformed_messages = message_token_limiter.apply_transform(messages) + stats_str, had_effect = message_token_limiter.get_stats(pre_transform_messages, transformed_messages) + assert had_effect == expected_effect + assert stats_str == expected_stats if __name__ == "__main__": @@ -110,12 +105,18 @@ def test_MessageTokenLimiter_get_stats_str_short(message_token_limiter, short_me message_history_limiter = MessageHistoryLimiter(max_messages=3) message_token_limiter = MessageTokenLimiter(max_tokens_per_message=3) - test_MessageHistoryLimiter_apply_transform_long(message_history_limiter, long_messages) - test_MessageHistoryLimiter_apply_transform_short(message_history_limiter, short_messages) - test_MessageHistoryLimiter_get_stats_str_long(message_history_limiter, long_messages) - test_MessageHistoryLimiter_get_stats_str_short(message_history_limiter, short_messages) - - test_MessageTokenLimiter_apply_transform_long(message_token_limiter, long_messages) - test_MessageTokenLimiter_apply_transform_short(message_token_limiter, short_messages) - test_MessageTokenLimiter_get_stats_str_long(message_token_limiter, long_messages) - test_MessageTokenLimiter_get_stats_str_short(message_token_limiter, short_messages) + # Call the MessageHistoryLimiter tests + test_message_history_limiter_apply_transform(message_history_limiter, long_messages, 3) + test_message_history_limiter_apply_transform(message_history_limiter, short_messages, 3) + test_message_history_limiter_get_stats( + message_history_limiter, long_messages, "Removed 2 messages. Number of messages reduced from 5 to 3.", True + ) + test_message_history_limiter_get_stats(message_history_limiter, short_messages, "No messages were removed.", False) + + # Call the MessageTokenLimiter tests + test_message_token_limiter_apply_transform(message_token_limiter, long_messages, 9) + test_message_token_limiter_apply_transform(message_token_limiter, short_messages, 3) + test_message_token_limiter_get_stats( + message_token_limiter, long_messages, "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True + ) + test_message_token_limiter_get_stats(message_token_limiter, short_messages, "No tokens were truncated.", False) From ffbb67bb08b792600e4f9662a2fc96c770c0bf07 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Tue, 9 Apr 2024 10:41:12 +0300 Subject: [PATCH 07/13] t i# This is a combination of 3 commits. Update requirements --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 650d785c1e06..88e5a650277a 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ # Disallowing 2.6.0 can be removed when this is fixed https://github.com/pydantic/pydantic/issues/8705 "pydantic>=1.10,<3,!=2.6.0", # could be both V1 and V2 "docker", + "pytest-lazy-fixture", ] jupyter_executor = [ From a44a4d3d0d6855be6590b05b361294606e4c006d Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Tue, 9 Apr 2024 14:28:46 +0300 Subject: [PATCH 08/13] Remove lazy-fixture --- setup.py | 1 - .../contrib/capabilities/test_transforms.py | 37 +++++-------------- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 88e5a650277a..650d785c1e06 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,6 @@ # Disallowing 2.6.0 can be removed when this is fixed https://github.com/pydantic/pydantic/issues/8705 "pydantic>=1.10,<3,!=2.6.0", # could be both V1 and V2 "docker", - "pytest-lazy-fixture", ] jupyter_executor = [ diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index 407abaa33551..5832426e6200 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -2,13 +2,11 @@ from typing import Dict, List import pytest -from pytest_lazyfixture import lazy_fixture from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens -@pytest.fixture -def long_messages() -> List[Dict]: +def get_long_messages() -> List[Dict]: return [ {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, {"role": "user", "content": "very very very very very very long string"}, @@ -18,8 +16,7 @@ def long_messages() -> List[Dict]: ] -@pytest.fixture -def short_messages() -> List[Dict]: +def get_short_messages() -> List[Dict]: return [ {"role": "user", "content": "hello"}, {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, @@ -40,9 +37,7 @@ def message_token_limiter() -> MessageTokenLimiter: # MessageHistoryLimiter tests -@pytest.mark.parametrize( - "messages, expected_len", [(lazy_fixture("long_messages"), 3), (lazy_fixture("short_messages"), 3)] -) +@pytest.mark.parametrize("messages, expected_len", [(get_long_messages(), 3), (get_short_messages(), 3)]) def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_len): transformed_messages = message_history_limiter.apply_transform(messages) assert len(transformed_messages) == expected_len @@ -51,8 +46,8 @@ def test_message_history_limiter_apply_transform(message_history_limiter, messag @pytest.mark.parametrize( "messages, expected_stats, expected_effect", [ - (lazy_fixture("long_messages"), "Removed 2 messages. Number of messages reduced from 5 to 3.", True), - (lazy_fixture("short_messages"), "No messages were removed.", False), + (get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True), + (get_short_messages(), "No messages were removed.", False), ], ) def test_message_history_limiter_get_stats(message_history_limiter, messages, expected_stats, expected_effect): @@ -66,9 +61,7 @@ def test_message_history_limiter_get_stats(message_history_limiter, messages, ex # MessageTokenLimiter tests -@pytest.mark.parametrize( - "messages, expected_token_count", [(lazy_fixture("long_messages"), 9), (lazy_fixture("short_messages"), 3)] -) +@pytest.mark.parametrize("messages, expected_token_count", [(get_long_messages(), 9), (get_short_messages(), 3)]) def test_message_token_limiter_apply_transform(message_token_limiter, messages, expected_token_count): transformed_messages = message_token_limiter.apply_transform(messages) assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == expected_token_count @@ -77,8 +70,8 @@ def test_message_token_limiter_apply_transform(message_token_limiter, messages, @pytest.mark.parametrize( "messages, expected_stats, expected_effect", [ - (lazy_fixture("long_messages"), "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True), - (lazy_fixture("short_messages"), "No tokens were truncated.", False), + (get_long_messages(), "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True), + (get_short_messages(), "No tokens were truncated.", False), ], ) def test_message_token_limiter_get_stats(message_token_limiter, messages, expected_stats, expected_effect): @@ -90,18 +83,8 @@ def test_message_token_limiter_get_stats(message_token_limiter, messages, expect if __name__ == "__main__": - long_messages = [ - {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, - {"role": "user", "content": "very very very very very very long string"}, - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, - {"role": "user", "content": "how"}, - ] - short_messages = [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, - {"role": "user", "content": "how"}, - ] + long_messages = get_long_messages() + short_messages = get_short_messages() message_history_limiter = MessageHistoryLimiter(max_messages=3) message_token_limiter = MessageTokenLimiter(max_tokens_per_message=3) From cbad5a0d4b320d8d006357ea152149af73be5988 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Fri, 12 Apr 2024 10:09:26 +0300 Subject: [PATCH 09/13] Avoid calling apply_transform in two code paths --- .../contrib/capabilities/transform_messages.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 9bfb9dc967e6..38bfe4c93e5a 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -70,14 +70,16 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: post_transform_messages.pop(0) for transform in self._transforms: + # deepcopy in case pre_transform_messages will late be used for stats + pre_transform_messages = ( + copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages + ) + post_transform_messages = transform.apply_transform(pre_transform_messages) + if self._verbose: - pre_transform_messages = copy.deepcopy(post_transform_messages) - post_transform_messages = transform.apply_transform(post_transform_messages) stats_str, had_effect = transform.get_stats(pre_transform_messages, post_transform_messages) if had_effect: print(colored(stats_str, "yellow")) - else: - post_transform_messages = transform.apply_transform(post_transform_messages) if system_message: post_transform_messages.insert(0, system_message) From 7d990550bb9d6b8baebdfcf08665e8ab4315c076 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Fri, 12 Apr 2024 10:16:02 +0300 Subject: [PATCH 10/13] Format --- autogen/agentchat/contrib/capabilities/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 03abb4a124b7..034232381ef8 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -177,6 +177,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: processed_messages.insert(0, msg) return processed_messages + def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: pre_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in pre_transform_messages) post_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in post_transform_messages) @@ -189,7 +190,6 @@ def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: return stats_str, True return "No tokens were truncated.", False - def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]: if isinstance(contents, str): return self._truncate_tokens(contents, n_tokens) From b5b4f58d7c70f16bd4e62fdf3d3a98b185d73657 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Sat, 13 Apr 2024 15:20:22 +0300 Subject: [PATCH 11/13] Replace stats with logs --- .../capabilities/transform_messages.py | 8 +++---- .../contrib/capabilities/transforms.py | 18 +++++++------- .../contrib/capabilities/test_transforms.py | 24 +++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py index 38bfe4c93e5a..e96dc39fa7bc 100644 --- a/autogen/agentchat/contrib/capabilities/transform_messages.py +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -46,7 +46,7 @@ def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = T """ Args: transforms: A list of message transformations to apply. - verbose: Whether to print stats of each transformation or not. + verbose: Whether to print logs of each transformation or not. """ self._transforms = transforms self._verbose = verbose @@ -70,16 +70,16 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]: post_transform_messages.pop(0) for transform in self._transforms: - # deepcopy in case pre_transform_messages will late be used for stats + # deepcopy in case pre_transform_messages will later be used for logs printing pre_transform_messages = ( copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages ) post_transform_messages = transform.apply_transform(pre_transform_messages) if self._verbose: - stats_str, had_effect = transform.get_stats(pre_transform_messages, post_transform_messages) + logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages) if had_effect: - print(colored(stats_str, "yellow")) + print(colored(logs_str, "yellow")) if system_message: post_transform_messages.insert(0, system_message) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 034232381ef8..4e6eba8d6851 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -26,8 +26,8 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: """ ... - def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: - """Creates the string including the stats of the transformation + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + """Creates the string including the logs of the transformation Alongside the string, it returns a boolean indicating whether the transformation had an effect or not. @@ -36,7 +36,7 @@ def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: post_transform_messages: A list of dictionaries representig messages after the transformation. Returns: - A tuple with a string with the stats and a flag indicating whether the transformation had an effect or not. + A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not. """ ... @@ -75,16 +75,16 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return messages[-self._max_messages :] - def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: pre_transform_messages_len = len(pre_transform_messages) post_transform_messages_len = len(post_transform_messages) if post_transform_messages_len < pre_transform_messages_len: - stats_str = ( + logs_str = ( f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}." ) - return stats_str, True + return logs_str, True return "No messages were removed.", False def _validate_max_messages(self, max_messages: Optional[int]): @@ -178,16 +178,16 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return processed_messages - def get_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: pre_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in pre_transform_messages) post_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in post_transform_messages) if post_transform_messages_tokens < pre_transform_messages_tokens: - stats_str = ( + logs_str = ( f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" ) - return stats_str, True + return logs_str, True return "No tokens were truncated.", False def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]: diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index 5832426e6200..cc4d8c905c4f 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -44,18 +44,18 @@ def test_message_history_limiter_apply_transform(message_history_limiter, messag @pytest.mark.parametrize( - "messages, expected_stats, expected_effect", + "messages, expected_logs, expected_effect", [ (get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True), (get_short_messages(), "No messages were removed.", False), ], ) -def test_message_history_limiter_get_stats(message_history_limiter, messages, expected_stats, expected_effect): +def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect): pre_transform_messages = copy.deepcopy(messages) transformed_messages = message_history_limiter.apply_transform(messages) - stats_str, had_effect = message_history_limiter.get_stats(pre_transform_messages, transformed_messages) + logs_str, had_effect = message_history_limiter.get_logs(pre_transform_messages, transformed_messages) assert had_effect == expected_effect - assert stats_str == expected_stats + assert logs_str == expected_logs # MessageTokenLimiter tests @@ -68,18 +68,18 @@ def test_message_token_limiter_apply_transform(message_token_limiter, messages, @pytest.mark.parametrize( - "messages, expected_stats, expected_effect", + "messages, expected_logs, expected_effect", [ (get_long_messages(), "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True), (get_short_messages(), "No tokens were truncated.", False), ], ) -def test_message_token_limiter_get_stats(message_token_limiter, messages, expected_stats, expected_effect): +def test_message_token_limiter_get_logs(message_token_limiter, messages, expected_logs, expected_effect): pre_transform_messages = copy.deepcopy(messages) transformed_messages = message_token_limiter.apply_transform(messages) - stats_str, had_effect = message_token_limiter.get_stats(pre_transform_messages, transformed_messages) + logs_str, had_effect = message_token_limiter.get_logs(pre_transform_messages, transformed_messages) assert had_effect == expected_effect - assert stats_str == expected_stats + assert logs_str == expected_logs if __name__ == "__main__": @@ -91,15 +91,15 @@ def test_message_token_limiter_get_stats(message_token_limiter, messages, expect # Call the MessageHistoryLimiter tests test_message_history_limiter_apply_transform(message_history_limiter, long_messages, 3) test_message_history_limiter_apply_transform(message_history_limiter, short_messages, 3) - test_message_history_limiter_get_stats( + test_message_history_limiter_get_logs( message_history_limiter, long_messages, "Removed 2 messages. Number of messages reduced from 5 to 3.", True ) - test_message_history_limiter_get_stats(message_history_limiter, short_messages, "No messages were removed.", False) + test_message_history_limiter_get_logs(message_history_limiter, short_messages, "No messages were removed.", False) # Call the MessageTokenLimiter tests test_message_token_limiter_apply_transform(message_token_limiter, long_messages, 9) test_message_token_limiter_apply_transform(message_token_limiter, short_messages, 3) - test_message_token_limiter_get_stats( + test_message_token_limiter_get_logs( message_token_limiter, long_messages, "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True ) - test_message_token_limiter_get_stats(message_token_limiter, short_messages, "No tokens were truncated.", False) + test_message_token_limiter_get_logs(message_token_limiter, short_messages, "No tokens were truncated.", False) From 67198d6d3e3d2b58c9ec17d756fbd0aecf0f5b70 Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Sat, 13 Apr 2024 17:58:11 +0300 Subject: [PATCH 12/13] Handle no content messages in TokenLimiter get_logs() --- autogen/agentchat/contrib/capabilities/transforms.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 4e6eba8d6851..6dc1d59fe9c7 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -179,8 +179,12 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return processed_messages def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: - pre_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in pre_transform_messages) - post_transform_messages_tokens = sum(_count_tokens(msg["content"]) for msg in post_transform_messages) + pre_transform_messages_tokens = sum( + _count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg + ) + post_transform_messages_tokens = sum( + _count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg + ) if post_transform_messages_tokens < pre_transform_messages_tokens: logs_str = ( From 6d1e4afaeddb1daaa364512cfe7fa89f2ed967ca Mon Sep 17 00:00:00 2001 From: giorgossideris Date: Sat, 13 Apr 2024 18:07:06 +0300 Subject: [PATCH 13/13] Move tests from test_transform_messages to test_transforms --- .../capabilities/test_transform_messages.py | 104 ------------------ .../contrib/capabilities/test_transforms.py | 29 ++++- 2 files changed, 23 insertions(+), 110 deletions(-) diff --git a/test/agentchat/contrib/capabilities/test_transform_messages.py b/test/agentchat/contrib/capabilities/test_transform_messages.py index ac0cdf58755b..4888b49f3274 100644 --- a/test/agentchat/contrib/capabilities/test_transform_messages.py +++ b/test/agentchat/contrib/capabilities/test_transform_messages.py @@ -1,4 +1,3 @@ -import copy import os import sys import tempfile @@ -7,7 +6,6 @@ import pytest import autogen -from autogen import token_count_utils from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter @@ -18,106 +16,6 @@ from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402 -def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int: - token_count = 0 - if isinstance(content, str): - token_count = token_count_utils.count_token(content) - elif isinstance(content, list): - for item in content: - token_count += _count_tokens(item.get("text", "")) - return token_count - - -def test_limit_token_transform(): - """ - Test the TokenLimitTransform capability. - """ - - messages = [ - {"role": "user", "content": "short string"}, - { - "role": "assistant", - "content": [{"type": "text", "text": "very very very very very very very very long string"}], - }, - ] - - # check if token limit per message is not exceeded. - max_tokens_per_message = 5 - token_limit_transform = MessageTokenLimiter(max_tokens_per_message=max_tokens_per_message) - transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages)) - - for message in transformed_messages: - assert _count_tokens(message["content"]) <= max_tokens_per_message - - # check if total token limit is not exceeded. - max_tokens = 10 - token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens) - transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages)) - - token_count = 0 - for message in transformed_messages: - token_count += _count_tokens(message["content"]) - - assert token_count <= max_tokens - assert len(transformed_messages) <= len(messages) - - # check if token limit per message works nicely with total token limit. - token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens, max_tokens_per_message=max_tokens_per_message) - - transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages)) - - token_count = 0 - for message in transformed_messages: - token_count_local = _count_tokens(message["content"]) - token_count += token_count_local - assert token_count_local <= max_tokens_per_message - - assert token_count <= max_tokens - assert len(transformed_messages) <= len(messages) - - -def test_limit_token_transform_without_content(): - """Test the TokenLimitTransform with messages that don't have content.""" - - messages = [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}] - - # check if token limit per message works nicely with total token limit. - token_limit_transform = MessageTokenLimiter(max_tokens=10, max_tokens_per_message=5) - - transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages)) - - assert len(transformed_messages) == len(messages) - - -def test_limit_token_transform_total_token_count(): - """Tests if the TokenLimitTransform truncates without dropping messages.""" - messages = [{"role": "very very very very very"}] - - token_limit_transform = MessageTokenLimiter(max_tokens=1) - transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages)) - - assert len(transformed_messages) == 1 - - -def test_max_message_history_length_transform(): - """ - Test the MessageHistoryLimiter capability to limit the number of messages. - """ - messages = [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": [{"type": "text", "text": "there"}]}, - {"role": "user", "content": "how"}, - {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]}, - ] - - max_messages = 2 - messages_limiter = MessageHistoryLimiter(max_messages=max_messages) - transformed_messages = messages_limiter.apply_transform(copy.deepcopy(messages)) - - assert len(transformed_messages) == max_messages - assert transformed_messages == messages[max_messages:] - - @pytest.mark.skipif(skip_openai, reason="Requested to skip openai test.") def test_transform_messages_capability(): """Test the TransformMessages capability to handle long contexts. @@ -172,6 +70,4 @@ def test_transform_messages_capability(): if __name__ == "__main__": - test_limit_token_transform() - test_max_message_history_length_transform() test_transform_messages_capability() diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index cc4d8c905c4f..1a929e4c6ba1 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -24,6 +24,10 @@ def get_short_messages() -> List[Dict]: ] +def get_no_content_messages() -> List[Dict]: + return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}] + + @pytest.fixture def message_history_limiter() -> MessageHistoryLimiter: return MessageHistoryLimiter(max_messages=3) @@ -37,10 +41,13 @@ def message_token_limiter() -> MessageTokenLimiter: # MessageHistoryLimiter tests -@pytest.mark.parametrize("messages, expected_len", [(get_long_messages(), 3), (get_short_messages(), 3)]) -def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_len): +@pytest.mark.parametrize( + "messages, expected_messages_len", + [(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2)], +) +def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len): transformed_messages = message_history_limiter.apply_transform(messages) - assert len(transformed_messages) == expected_len + assert len(transformed_messages) == expected_messages_len @pytest.mark.parametrize( @@ -48,6 +55,7 @@ def test_message_history_limiter_apply_transform(message_history_limiter, messag [ (get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True), (get_short_messages(), "No messages were removed.", False), + (get_no_content_messages(), "No messages were removed.", False), ], ) def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect): @@ -61,10 +69,18 @@ def test_message_history_limiter_get_logs(message_history_limiter, messages, exp # MessageTokenLimiter tests -@pytest.mark.parametrize("messages, expected_token_count", [(get_long_messages(), 9), (get_short_messages(), 3)]) -def test_message_token_limiter_apply_transform(message_token_limiter, messages, expected_token_count): +@pytest.mark.parametrize( + "messages, expected_token_count, expected_messages_len", + [(get_long_messages(), 9, 5), (get_short_messages(), 3, 3), (get_no_content_messages(), 0, 2)], +) +def test_message_token_limiter_apply_transform( + message_token_limiter, messages, expected_token_count, expected_messages_len +): transformed_messages = message_token_limiter.apply_transform(messages) - assert sum(_count_tokens(msg["content"]) for msg in transformed_messages) == expected_token_count + assert ( + sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count + ) + assert len(transformed_messages) == expected_messages_len @pytest.mark.parametrize( @@ -72,6 +88,7 @@ def test_message_token_limiter_apply_transform(message_token_limiter, messages, [ (get_long_messages(), "Truncated 6 tokens. Number of tokens reduced from 15 to 9", True), (get_short_messages(), "No tokens were truncated.", False), + (get_no_content_messages(), "No tokens were truncated.", False), ], ) def test_message_token_limiter_get_logs(message_token_limiter, messages, expected_logs, expected_effect):