diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index bc56efd74d23..dad3fc335edf 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -1,5 +1,4 @@
 import copy
-import json
 import sys
 from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
 
@@ -8,8 +7,9 @@
 from autogen import token_count_utils
 from autogen.cache import AbstractCache, Cache
-from autogen.oai.openai_utils import filter_config
+from autogen.types import MessageContentType
 
+from . import transforms_util
 from .text_compressors import LLMLingua, TextCompressor
 
@@ -169,7 +169,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         assert self._min_tokens is not None
 
         # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
-        if not _min_tokens_reached(messages, self._min_tokens):
+        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
             return messages
 
         temp_messages = copy.deepcopy(messages)
@@ -178,13 +178,13 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
 
         for msg in reversed(temp_messages):
             # Some messages may not have content.
-            if not _is_content_right_type(msg.get("content")):
+            if not transforms_util.is_content_right_type(msg.get("content")):
                 processed_messages.insert(0, msg)
                 continue
 
-            if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
+            if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
                 processed_messages.insert(0, msg)
-                processed_messages_tokens += _count_tokens(msg["content"])
+                processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
                 continue
 
             expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
@@ -199,7 +199,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
                 break
 
             msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
-            msg_tokens = _count_tokens(msg["content"])
+            msg_tokens = transforms_util.count_text_tokens(msg["content"])
 
             # prepend the message to the list to preserve order
             processed_messages_tokens += msg_tokens
@@ -209,10 +209,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
 
     def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
         pre_transform_messages_tokens = sum(
-            _count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
+            transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
         )
         post_transform_messages_tokens = sum(
-            _count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
+            transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
         )
 
         if post_transform_messages_tokens < pre_transform_messages_tokens:
@@ -349,31 +349,32 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
             return messages
 
         # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
-        if not _min_tokens_reached(messages, self._min_tokens):
+        if not transforms_util.min_tokens_reached(messages, self._min_tokens):
            return messages
 
        total_savings = 0
        processed_messages = messages.copy()
        for message in processed_messages:
            # Some messages may not have content.
-            if not _is_content_right_type(message.get("content")):
+            if not transforms_util.is_content_right_type(message.get("content")):
                 continue
 
-            if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
+            if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
                 continue
 
-            if _is_content_text_empty(message["content"]):
+            if transforms_util.is_content_text_empty(message["content"]):
                 continue
 
-            cached_content = self._cache_get(message["content"])
+            cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
+            cached_content = transforms_util.cache_content_get(self._cache, cache_key)
             if cached_content is not None:
-                savings, compressed_content = cached_content
+                message["content"], savings = cached_content
             else:
-                savings, compressed_content = self._compress(message["content"])
+                message["content"], savings = self._compress(message["content"])
 
-            self._cache_set(message["content"], compressed_content, savings)
+                transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
 
-            message["content"] = compressed_content
+            assert isinstance(savings, int)
             total_savings += savings
 
         self._recent_tokens_savings = total_savings
@@ -385,24 +386,29 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
         else:
             return "No tokens saved with text compression.", False
 
-    def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
+    def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
         """Compresses the given text or multimodal content using the specified compression method."""
         if isinstance(content, str):
             return self._compress_text(content)
         elif isinstance(content, list):
             return self._compress_multimodal(content)
         else:
-            return 0, content
+            return content, 0
 
-    def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
+    def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
         tokens_saved = 0
-        for msg in content:
-            if "text" in msg:
-                savings, msg["text"] = self._compress_text(msg["text"])
+        for item in content:
+            if isinstance(item, dict) and "text" in item:
+                item["text"], savings = self._compress_text(item["text"])
+                tokens_saved += savings
+
+            elif isinstance(item, str):
+                item, savings = self._compress_text(item)
                 tokens_saved += savings
-        return tokens_saved, content
 
-    def _compress_text(self, text: str) -> Tuple[int, str]:
+        return content, tokens_saved
+
+    def _compress_text(self, text: str) -> Tuple[str, int]:
         """Compresses the given text using the specified compression method."""
         compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
 
@@ -410,63 +416,8 @@ def _compress_text(self, text: str) -> Tuple[int, str]:
         if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
             savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
 
-        return savings, compressed_text["compressed_prompt"]
-
-    def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
-        if self._cache:
-            cached_value = self._cache.get(self._cache_key(content))
-            if cached_value:
-                return cached_value
-
-    def _cache_set(
-        self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
-    ):
-        if self._cache:
-            value = (tokens_saved, compressed_content)
-            self._cache.set(self._cache_key(content), value)
-
-    def _cache_key(self, content: Union[str, List[Dict]]) -> str:
-        return f"{json.dumps(content)}_{self._min_tokens}"
+        return compressed_text["compressed_prompt"], savings
 
     def _validate_min_tokens(self, min_tokens: Optional[int]):
         if min_tokens is not None and min_tokens <= 0:
             raise ValueError("min_tokens must be greater than 0 or None")
-
-
-def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
-    """Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
-    if not min_tokens:
-        return True
-
-    messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
-    return messages_tokens >= min_tokens
-
-
-def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
-    token_count = 0
-    if isinstance(content, str):
-        token_count = token_count_utils.count_token(content)
-    elif isinstance(content, list):
-        for item in content:
-            token_count += _count_tokens(item.get("text", ""))
-    return token_count
-
-
-def _is_content_right_type(content: Any) -> bool:
-    return isinstance(content, (str, list))
-
-
-def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
-    if isinstance(content, str):
-        return content == ""
-    elif isinstance(content, list):
-        return all(_is_content_text_empty(item.get("text", "")) for item in content)
-    else:
-        return False
-
-
-def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
-    if not filter_dict:
-        return True
-
-    return len(filter_config([message], filter_dict, exclude)) > 0
diff --git a/autogen/agentchat/contrib/capabilities/transforms_util.py b/autogen/agentchat/contrib/capabilities/transforms_util.py
new file mode 100644
index 000000000000..8678dec654c4
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transforms_util.py
@@ -0,0 +1,114 @@
+from typing import Any, Dict, Hashable, List, Optional, Tuple
+
+from autogen import token_count_utils
+from autogen.cache.abstract_cache_base import AbstractCache
+from autogen.oai.openai_utils import filter_config
+from autogen.types import MessageContentType
+
+
+def cache_key(content: MessageContentType, *args: Hashable) -> str:
+    """Calculates the cache key for the given message content and any other hashable args.
+
+    Args:
+        content (MessageContentType): The message content to calculate the cache key for.
+        *args: Any additional hashable args to include in the cache key.
+    """
+    str_keys = [str(key) for key in (content, *args)]
+    return "".join(str_keys)
+
+
+def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
+    """Retrieves cached content from the cache.
+
+    Args:
+        cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
+        key (str): The key to retrieve the content from.
+    """
+    if cache:
+        cached_value = cache.get(key)
+        if cached_value:
+            return cached_value
+
+
+def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
+    """Sets content into the cache.
+
+    Args:
+        cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
+        key (str): The key to set the content into.
+        content (MessageContentType): The message content to set into the cache.
+        *extra_values: Additional values to be passed to the cache.
+    """
+    if cache:
+        cache_value = (content, *extra_values)
+        cache.set(key, cache_value)
+
+
+def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
+    """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.
+
+    Args:
+        messages (List[Dict]): A list of messages to check.
+    """
+    if not min_tokens:
+        return True
+
+    messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg)
+    return messages_tokens >= min_tokens
+
+
+def count_text_tokens(content: MessageContentType) -> int:
+    """Calculates the number of text tokens in the given message content.
+
+    Args:
+        content (MessageContentType): The message content to calculate the number of text tokens for.
+    """
+    token_count = 0
+    if isinstance(content, str):
+        token_count = token_count_utils.count_token(content)
+    elif isinstance(content, list):
+        for item in content:
+            if isinstance(item, str):
+                token_count += token_count_utils.count_token(item)
+            else:
+                token_count += count_text_tokens(item.get("text", ""))
+    return token_count
+
+
+def is_content_right_type(content: Any) -> bool:
+    """A helper function to check if the passed in content is of the right type."""
+    return isinstance(content, (str, list))
+
+
+def is_content_text_empty(content: MessageContentType) -> bool:
+    """Checks if the content of the message does not contain any text.
+
+    Args:
+        content (MessageContentType): The message content to check.
+    """
+    if isinstance(content, str):
+        return content == ""
+    elif isinstance(content, list):
+        texts = []
+        for item in content:
+            if isinstance(item, str):
+                texts.append(item)
+            elif isinstance(item, dict):
+                texts.append(item.get("text", ""))
+        return not any(texts)
+    else:
+        return True
+
+
+def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
+    """Validates whether the transform should be applied according to the filter dictionary.
+
+    Args:
+        message (Dict[str, Any]): The message to validate.
+        filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
+        exclude (bool): Whether to exclude messages that match the filter dictionary.
+    """
+    if not filter_dict:
+        return True
+
+    return len(filter_config([message], filter_dict, exclude)) > 0
diff --git a/autogen/types.py b/autogen/types.py
index 77ca70b70b97..461765a6adcd 100644
--- a/autogen/types.py
+++ b/autogen/types.py
@@ -1,5 +1,7 @@
 from typing import Dict, List, Literal, TypedDict, Union
 
+MessageContentType = Union[str, List[Union[Dict, str]], None]
+
 
 class UserMessageTextContentPart(TypedDict):
     type: Literal["text"]
diff --git a/test/agentchat/contrib/capabilities/test_transforms_util.py b/test/agentchat/contrib/capabilities/test_transforms_util.py
new file mode 100644
index 000000000000..089ebbdc8db2
--- /dev/null
+++ b/test/agentchat/contrib/capabilities/test_transforms_util.py
@@ -0,0 +1,72 @@
+import itertools
+import tempfile
+from typing import Dict, Tuple
+
+import pytest
+
+from autogen.agentchat.contrib.capabilities import transforms_util
+from autogen.cache.cache import Cache
+from autogen.types import MessageContentType
+
+MESSAGES = {
+    "message1": {
+        "content": [{"text": "Hello"}, {"image_url": {"url": "https://example.com/image.jpg"}}],
+        "text_tokens": 1,
+    },
+    "message2": {"content": [{"image_url": {"url": "https://example.com/image.jpg"}}], "text_tokens": 0},
+    "message3": {"content": [{"text": "Hello"}, {"text": "World"}], "text_tokens": 2},
+    "message4": {"content": None, "text_tokens": 0},
+    "message5": {"content": "Hello there!", "text_tokens": 3},
+    "message6": {"content": ["Hello there!", "Hello there!"], "text_tokens": 6},
+}
+
+
+@pytest.mark.parametrize("message", MESSAGES.values())
+def test_cache_content(message: Dict[str, MessageContentType]) -> None:
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        cache = Cache.disk(tmpdirname)
+        cache_key_1 = "test_string"
+
+        transforms_util.cache_content_set(cache, cache_key_1, message["content"])
+        assert transforms_util.cache_content_get(cache, cache_key_1) == (message["content"],)
+
+        cache_key_2 = "test_list"
+        cache_value_2 = [message["content"], 1, "some_string", {"new_key": "new_value"}]
+        transforms_util.cache_content_set(cache, cache_key_2, *cache_value_2)
+        assert transforms_util.cache_content_get(cache, cache_key_2) == tuple(cache_value_2)
+        assert isinstance(cache_value_2[1], int)
+        assert isinstance(cache_value_2[2], str)
+        assert isinstance(cache_value_2[3], dict)
+
+        cache_key_3 = "test_None"
+        transforms_util.cache_content_set(None, cache_key_3, message["content"])
+        assert transforms_util.cache_content_get(cache, cache_key_3) is None
+        assert transforms_util.cache_content_get(None, cache_key_3) is None
+
+
+@pytest.mark.parametrize("messages", itertools.product(MESSAGES.values(), MESSAGES.values()))
+def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, MessageContentType]]) -> None:
+    message_1, message_2 = messages
+    cache_1 = transforms_util.cache_key(message_1["content"], 10)
+    cache_2 = transforms_util.cache_key(message_2["content"], 10)
+    if message_1 == message_2:
+        assert cache_1 == cache_2
+    else:
+        assert cache_1 != cache_2
+
+
+@pytest.mark.parametrize("message", MESSAGES.values())
+def test_min_tokens_reached(message: Dict[str, MessageContentType]):
+    assert transforms_util.min_tokens_reached([message], None)
+    assert transforms_util.min_tokens_reached([message], 0)
+    assert not transforms_util.min_tokens_reached([message], message["text_tokens"] + 1)
+
+
+@pytest.mark.parametrize("message", MESSAGES.values())
+def test_count_text_tokens(message: Dict[str, MessageContentType]):
+    assert transforms_util.count_text_tokens(message["content"]) == message["text_tokens"]
+
+
+@pytest.mark.parametrize("message", MESSAGES.values())
+def test_is_content_text_empty(message: Dict[str, MessageContentType]):
+    assert transforms_util.is_content_text_empty(message["content"]) == (message["text_tokens"] == 0)
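
For reviewers who want to exercise the new helpers outside the transforms, below is a minimal usage sketch of the `transforms_util` functions introduced in this diff (`cache_key`, `cache_content_get`, `cache_content_set`, `count_text_tokens`, `min_tokens_reached`, `is_content_right_type`). It mirrors the refactored `TextMessageCompressor.apply_transform` flow; the example messages, the temporary cache directory, and the stand-in for `self._compress(...)` are illustrative only and not part of the change.

```python
import tempfile

from autogen.agentchat.contrib.capabilities import transforms_util
from autogen.cache.cache import Cache

# Illustrative messages only: one plain-text message and one multimodal message.
messages = [
    {"role": "user", "content": "Please summarize the meeting notes below ..."},
    {"role": "user", "content": [{"text": "Hello"}, {"image_url": {"url": "https://example.com/image.jpg"}}]},
]

MIN_TOKENS = 3

# Token counting handles both str content and list (multimodal) content.
print(sum(transforms_util.count_text_tokens(m["content"]) for m in messages))
print(transforms_util.min_tokens_reached(messages, MIN_TOKENS))

with tempfile.TemporaryDirectory() as tmpdir:
    cache = Cache.disk(tmpdir)  # same pattern as the new unit tests

    for message in messages:
        if not transforms_util.is_content_right_type(message.get("content")):
            continue

        # The key folds in the min_tokens threshold, as the refactored compressor does.
        key = transforms_util.cache_key(message["content"], MIN_TOKENS)

        cached = transforms_util.cache_content_get(cache, key)
        if cached is not None:
            message["content"], savings = cached
        else:
            # Stand-in for self._compress(...); a real run would invoke the text compressor.
            compressed, savings = message["content"], 0
            transforms_util.cache_content_set(cache, key, compressed, savings)
            message["content"] = compressed
```

Because `cache_key` folds the `min_tokens` value into the key, changing the threshold naturally invalidates previously cached compressions, preserving the behavior of the removed `_cache_key` method (`f"{json.dumps(content)}_{self._min_tokens}"`).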