diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
index 8303843e8818..bc56efd74d23 100644
--- a/autogen/agentchat/contrib/capabilities/transforms.py
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -8,6 +8,7 @@
 
 from autogen import token_count_utils
 from autogen.cache import AbstractCache, Cache
+from autogen.oai.openai_utils import filter_config
 
 from .text_compressors import LLMLingua, TextCompressor
 
@@ -130,6 +131,8 @@ def __init__(
         max_tokens: Optional[int] = None,
         min_tokens: Optional[int] = None,
         model: str = "gpt-3.5-turbo-0613",
+        filter_dict: Optional[Dict] = None,
+        exclude_filter: bool = True,
     ):
         """
         Args:
@@ -140,11 +143,17 @@
             min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
                 Must be greater than or equal to 0 if not None.
             model (str): The target OpenAI model for tokenization alignment.
+            filter_dict (None or dict): A dictionary used to select which messages the truncation is applied to.
+                If None, no filter will be applied.
+            exclude_filter (bool): If True (the default), messages matching the filter will be excluded from token
+                truncation. If False, only messages matching the filter will be truncated.
         """
         self._model = model
         self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
         self._max_tokens = self._validate_max_tokens(max_tokens)
         self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
+        self._filter_dict = filter_dict
+        self._exclude_filter = exclude_filter
 
     def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """Applies token truncation to the conversation history.
@@ -169,10 +178,15 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
 
         for msg in reversed(temp_messages):
             # Some messages may not have content.
-            if not isinstance(msg.get("content"), (str, list)):
+            if not _is_content_right_type(msg.get("content")):
                 processed_messages.insert(0, msg)
                 continue
 
+            if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
+                processed_messages.insert(0, msg)
+                processed_messages_tokens += _count_tokens(msg["content"])
+                continue
+
             expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
 
             # If adding this message would exceed the token limit, truncate the last message to meet the total token
@@ -282,6 +296,8 @@ def __init__(
         min_tokens: Optional[int] = None,
         compression_params: Dict = dict(),
         cache: Optional[AbstractCache] = Cache.disk(),
+        filter_dict: Optional[Dict] = None,
+        exclude_filter: bool = True,
     ):
         """
         Args:
@@ -293,6 +309,10 @@
                 dictionary.
             cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
                 If None, no caching will be used.
+            filter_dict (None or dict): A dictionary used to select which messages the compression is applied to.
+                If None, no filter will be applied.
+            exclude_filter (bool): If True (the default), messages matching the filter will be excluded from
+                compression. If False, only messages matching the filter will be compressed.
""" if text_compressor is None: @@ -303,6 +323,8 @@ def __init__( self._text_compressor = text_compressor self._min_tokens = min_tokens self._compression_args = compression_params + self._filter_dict = filter_dict + self._exclude_filter = exclude_filter self._cache = cache # Optimizing savings calculations to optimize log generation @@ -334,7 +356,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: processed_messages = messages.copy() for message in processed_messages: # Some messages may not have content. - if not isinstance(message.get("content"), (str, list)): + if not _is_content_right_type(message.get("content")): + continue + + if not _should_transform_message(message, self._filter_dict, self._exclude_filter): continue if _is_content_text_empty(message["content"]): @@ -397,7 +422,7 @@ def _cache_set( self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int ): if self._cache: - value = (tokens_saved, json.dumps(compressed_content)) + value = (tokens_saved, compressed_content) self._cache.set(self._cache_key(content), value) def _cache_key(self, content: Union[str, List[Dict]]) -> str: @@ -427,6 +452,10 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int: return token_count +def _is_content_right_type(content: Any) -> bool: + return isinstance(content, (str, list)) + + def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool: if isinstance(content, str): return content == "" @@ -434,3 +463,10 @@ def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool: return all(_is_content_text_empty(item.get("text", "")) for item in content) else: return False + + +def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool: + if not filter_dict: + return True + + return len(filter_config([message], filter_dict, exclude)) > 0 diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index f8dad6d79841..1ed347f6271c 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -379,11 +379,10 @@ def config_list_gpt4_gpt35( def filter_config( config_list: List[Dict[str, Any]], filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]], + exclude: bool = False, ) -> List[Dict[str, Any]]: - """ - This function filters `config_list` by checking each configuration dictionary against the - criteria specified in `filter_dict`. A configuration dictionary is retained if for every - key in `filter_dict`, see example below. + """This function filters `config_list` by checking each configuration dictionary against the criteria specified in + `filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below. Args: config_list (list of dict): A list of configuration dictionaries to be filtered. @@ -394,71 +393,68 @@ def filter_config( when it is found in the list of acceptable values. If the configuration's field's value is a list, then a match occurs if there is a non-empty intersection with the acceptable values. - - + exclude (bool): If False (the default value), configs that match the filter will be included in the returned + list. If True, configs that match the filter will be excluded in the returned list. Returns: list of dict: A list of configuration dictionaries that meet all the criteria specified in `filter_dict`. 
     Example:
-    ```python
-    # Example configuration list with various models and API types
-    configs = [
-        {'model': 'gpt-3.5-turbo'},
-        {'model': 'gpt-4'},
-        {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
-        {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
-    ]
-
-    # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
-    # that are also using the 'azure' API type
-    filter_criteria = {
-        'model': ['gpt-3.5-turbo'],  # Only accept configurations for 'gpt-3.5-turbo'
-        'api_type': ['azure']  # Only accept configurations for 'azure' API type
-    }
-
-    # Apply the filter to the configuration list
-    filtered_configs = filter_config(configs, filter_criteria)
-
-    # The resulting `filtered_configs` will be:
-    # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
-
-
-    # Define a filter to select a given tag
-    filter_criteria = {
-        'tags': ['gpt35_turbo'],
-    }
-
-    # Apply the filter to the configuration list
-    filtered_configs = filter_config(configs, filter_criteria)
-
-    # The resulting `filtered_configs` will be:
-    # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
-    ```
-
+        ```python
+        # Example configuration list with various models and API types
+        configs = [
+            {'model': 'gpt-3.5-turbo'},
+            {'model': 'gpt-4'},
+            {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
+            {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
+        ]
+        # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
+        # that are also using the 'azure' API type
+        filter_criteria = {
+            'model': ['gpt-3.5-turbo'],  # Only accept configurations for 'gpt-3.5-turbo'
+            'api_type': ['azure']  # Only accept configurations for 'azure' API type
+        }
+        # Apply the filter to the configuration list
+        filtered_configs = filter_config(configs, filter_criteria)
+        # The resulting `filtered_configs` will be:
+        # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
+        # Define a filter to select a given tag
+        filter_criteria = {
+            'tags': ['gpt35_turbo'],
+        }
+        # Apply the filter to the configuration list
+        filtered_configs = filter_config(configs, filter_criteria)
+        # The resulting `filtered_configs` will be:
+        # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
+        ```
 
     Note:
        - If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
        - If a configuration dictionary in `config_list` does not contain a key
          specified in `filter_dict`, it is considered a non-match and is excluded from the result.
        - If the list of acceptable values for a key in `filter_dict` includes None, then configuration
          dictionaries that do not have that key will also be considered a match.
- """ - def _satisfies(config_value: Any, acceptable_values: Any) -> bool: - if isinstance(config_value, list): - return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection - else: - return config_value in acceptable_values + """ if filter_dict: - config_list = [ - config - for config in config_list - if all(_satisfies(config.get(key), value) for key, value in filter_dict.items()) + return [ + item + for item in config_list + if all(_satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items()) ] return config_list +def _satisfies_criteria(value: Any, criteria_values: Any) -> bool: + if value is None: + return False + + if isinstance(value, list): + return bool(set(value) & set(criteria_values)) # Non-empty intersection + else: + return value in criteria_values + + def config_list_from_json( env_or_file: str, file_location: Optional[str] = "", @@ -785,3 +781,10 @@ def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Di assistant_update_kwargs["file_ids"] = assistant_config["file_ids"] return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs) + + +def _satisfies(config_value: Any, acceptable_values: Any) -> bool: + if isinstance(config_value, list): + return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection + else: + return config_value in acceptable_values diff --git a/test/agentchat/contrib/capabilities/test_transforms.py b/test/agentchat/contrib/capabilities/test_transforms.py index c5ffc08f112b..46c61d9adc6f 100644 --- a/test/agentchat/contrib/capabilities/test_transforms.py +++ b/test/agentchat/contrib/capabilities/test_transforms.py @@ -1,10 +1,21 @@ import copy -from typing import Dict, List +from typing import Any, Dict, List from unittest.mock import MagicMock, patch import pytest -from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens +from autogen.agentchat.contrib.capabilities.text_compressors import TextCompressor +from autogen.agentchat.contrib.capabilities.transforms import ( + MessageHistoryLimiter, + MessageTokenLimiter, + TextMessageCompressor, + _count_tokens, +) + + +class _MockTextCompressor: + def compress_text(self, text: str, **compression_params) -> Dict[str, Any]: + return {"compressed_prompt": ""} def get_long_messages() -> List[Dict]: @@ -29,6 +40,18 @@ def get_no_content_messages() -> List[Dict]: return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}] +def get_text_compressors() -> List[TextCompressor]: + compressors: List[TextCompressor] = [_MockTextCompressor()] + try: + from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua + + compressors.append(LLMLingua()) + except ImportError: + pass + + return compressors + + @pytest.fixture def message_history_limiter() -> MessageHistoryLimiter: return MessageHistoryLimiter(max_messages=3) @@ -44,6 +67,30 @@ def message_token_limiter_with_threshold() -> MessageTokenLimiter: return MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10) +def _filter_dict_test( + post_transformed_message: Dict, pre_transformed_messages: Dict, roles: List[str], exclude_filter: bool +) -> bool: + is_role = post_transformed_message["role"] in roles + if exclude_filter: + is_role = not is_role + + if isinstance(post_transformed_message["content"], list): + condition = ( + len(post_transformed_message["content"][0]["text"]) < len(pre_transformed_messages["content"][0]["text"]) + if 
is_role + else len(post_transformed_message["content"][0]["text"]) + == len(pre_transformed_messages["content"][0]["text"]) + ) + else: + condition = ( + len(post_transformed_message["content"]) < len(pre_transformed_messages["content"]) + if is_role + else len(post_transformed_message["content"]) == len(pre_transformed_messages["content"]) + ) + + return condition + + # MessageHistoryLimiter @@ -82,13 +129,35 @@ def test_message_history_limiter_get_logs(message_history_limiter, messages, exp def test_message_token_limiter_apply_transform( message_token_limiter, messages, expected_token_count, expected_messages_len ): - transformed_messages = message_token_limiter.apply_transform(messages) + transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages)) assert ( sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count ) assert len(transformed_messages) == expected_messages_len +@pytest.mark.parametrize("messages", [get_long_messages(), get_short_messages()]) +def test_message_token_limiter_with_filter(messages): + # Test truncating all messages except for user + message_token_limiter = MessageTokenLimiter(max_tokens_per_message=0, filter_dict={"role": "user"}) + transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages)) + + pre_post_messages = zip(messages, transformed_messages) + + for pre_transform, post_transform in pre_post_messages: + assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=True) + + # Test truncating all user messages only + message_token_limiter = MessageTokenLimiter( + max_tokens_per_message=0, filter_dict={"role": "user"}, exclude_filter=False + ) + transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages)) + + pre_post_messages = zip(messages, transformed_messages) + for pre_transform, post_transform in pre_post_messages: + assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False) + + @pytest.mark.parametrize( "messages, expected_token_count, expected_messages_len", [(get_long_messages(), 5, 5), (get_short_messages(), 5, 3), (get_no_content_messages(), 0, 2)], @@ -119,49 +188,60 @@ def test_message_token_limiter_get_logs(message_token_limiter, messages, expecte assert logs_str == expected_logs -def test_text_compression(): +# TextMessageCompressor tests + + +@pytest.mark.parametrize("text_compressor", get_text_compressors()) +def test_text_compression(text_compressor): """Test the TextMessageCompressor transform.""" - try: - from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor - text_compressor = TextMessageCompressor() - except ImportError: - pytest.skip("LLM Lingua is not installed.") + compressor = TextMessageCompressor(text_compressor=text_compressor) text = "Run this test with a long string. 
" messages = [ - { - "role": "assistant", - "content": [{"type": "text", "text": "".join([text] * 3)}], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "".join([text] * 3)}], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "".join([text] * 3)}], - }, + {"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]}, + {"role": "role", "content": [{"type": "text", "text": "".join([text] * 3)}]}, + {"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]}, + {"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]}, ] - transformed_messages = text_compressor.apply_transform([{"content": text}]) + transformed_messages = compressor.apply_transform([{"content": text}]) assert len(transformed_messages[0]["content"]) < len(text) # Test compressing all messages - text_compressor = TextMessageCompressor() - transformed_messages = text_compressor.apply_transform(copy.deepcopy(messages)) - for message in transformed_messages: - assert len(message["content"][0]["text"]) < len(messages[0]["content"][0]["text"]) + compressor = TextMessageCompressor(text_compressor=text_compressor) + transformed_messages = compressor.apply_transform(copy.deepcopy(messages)) + pre_post_messages = zip(messages, transformed_messages) + for pre_transform, post_transform in pre_post_messages: + assert len(post_transform["content"][0]["text"]) < len(pre_transform["content"][0]["text"]) -def test_text_compression_cache(): - try: - from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor - except ImportError: - pytest.skip("LLM Lingua is not installed.") +@pytest.mark.parametrize("messages", [get_long_messages(), get_short_messages()]) +@pytest.mark.parametrize("text_compressor", get_text_compressors()) +def test_text_compression_with_filter(messages, text_compressor): + # Test truncating all messages except for user + compressor = TextMessageCompressor(text_compressor=text_compressor, filter_dict={"role": "user"}) + transformed_messages = compressor.apply_transform(copy.deepcopy(messages)) + + pre_post_messages = zip(messages, transformed_messages) + for pre_transform, post_transform in pre_post_messages: + assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=True) + + # Test truncating all user messages only + compressor = TextMessageCompressor( + text_compressor=text_compressor, filter_dict={"role": "user"}, exclude_filter=False + ) + transformed_messages = compressor.apply_transform(copy.deepcopy(messages)) + + pre_post_messages = zip(messages, transformed_messages) + for pre_transform, post_transform in pre_post_messages: + assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False) + +@pytest.mark.parametrize("text_compressor", get_text_compressors()) +def test_text_compression_cache(text_compressor): messages = get_long_messages() mock_compressed_content = (1, {"content": "mock"}) @@ -171,18 +251,18 @@ def test_text_compression_cache(): ) as mocked_get, patch( "autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock() ) as mocked_set: - text_compressor = TextMessageCompressor() + compressor = TextMessageCompressor(text_compressor=text_compressor) - text_compressor.apply_transform(messages) - text_compressor.apply_transform(messages) + compressor.apply_transform(messages) + compressor.apply_transform(messages) assert mocked_get.call_count == len(messages) assert mocked_set.call_count == 
len(messages) # We already populated the cache with the mock content # We need to test if we retrieve the correct content - text_compressor = TextMessageCompressor() - compressed_messages = text_compressor.apply_transform(messages) + compressor = TextMessageCompressor(text_compressor=text_compressor) + compressed_messages = compressor.apply_transform(messages) for message in compressed_messages: assert message["content"] == mock_compressed_content[1] diff --git a/test/oai/test_utils.py b/test/oai/test_utils.py index 8f4dd22145f2..d5ad84d8355d 100755 --- a/test/oai/test_utils.py +++ b/test/oai/test_utils.py @@ -4,6 +4,7 @@ import logging import os import tempfile +from typing import Dict, List from unittest import mock from unittest.mock import patch @@ -43,11 +44,13 @@ [ { "model": "gpt-3.5-turbo", - "api_type": "openai" + "api_type": "openai", + "tags": ["gpt35"] }, { "model": "gpt-4", - "api_type": "openai" + "api_type": "openai", + "tags": ["gpt4"] }, { "model": "gpt-35-turbo-v0301", @@ -65,6 +68,33 @@ ] """ +JSON_SAMPLE_DICT = json.loads(JSON_SAMPLE) + + +FILTER_CONFIG_TEST = [ + { + "filter_dict": {"tags": ["gpt35", "gpt4"]}, + "exclude": False, + "expected": JSON_SAMPLE_DICT[0:2], + }, + { + "filter_dict": {"tags": ["gpt35", "gpt4"]}, + "exclude": True, + "expected": JSON_SAMPLE_DICT[2:4], + }, + { + "filter_dict": {"api_type": "azure", "api_version": "2024-02-15-preview"}, + "exclude": False, + "expected": [JSON_SAMPLE_DICT[2]], + }, +] + + +def _compare_lists_of_dicts(list1: List[Dict], list2: List[Dict]) -> bool: + dump1 = sorted(json.dumps(d, sort_keys=True) for d in list1) + dump2 = sorted(json.dumps(d, sort_keys=True) for d in list2) + return dump1 == dump2 + @pytest.fixture def mock_os_environ(): @@ -72,6 +102,17 @@ def mock_os_environ(): yield +@pytest.mark.parametrize("test_case", FILTER_CONFIG_TEST) +def test_filter_config(test_case): + filter_dict = test_case["filter_dict"] + exclude = test_case["exclude"] + expected = test_case["expected"] + + config_list = filter_config(JSON_SAMPLE_DICT, filter_dict, exclude) + + assert _compare_lists_of_dicts(config_list, expected) + + def test_config_list_from_json(): with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_file: json_data = json.loads(JSON_SAMPLE)
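
For reviewers who want to sanity-check the new `exclude` semantics outside the test suite, here is a minimal sketch (the `configs` list is invented for illustration; `filter_config` and its `exclude` flag come from the patch above):

```python
from autogen.oai.openai_utils import filter_config

# Hypothetical config list; the tags mirror the ones added to JSON_SAMPLE in test/oai/test_utils.py.
configs = [
    {"model": "gpt-3.5-turbo", "tags": ["gpt35"]},
    {"model": "gpt-4", "tags": ["gpt4"]},
]

# exclude=False (the default): keep the configs that match the filter.
assert filter_config(configs, {"tags": ["gpt4"]}) == [configs[1]]

# exclude=True: drop the configs that match the filter, keep everything else.
assert filter_config(configs, {"tags": ["gpt4"]}, exclude=True) == [configs[0]]
```

Note that `_satisfies_criteria` returns False when the key is missing, so with `exclude=True` a config lacking the filtered key is always kept.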
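The transform-level `filter_dict`/`exclude_filter` parameters compose the same way; a sketch assuming `pyautogen` is installed (`TextMessageCompressor` additionally needs the optional `llmlingua` extra, which is why the tests above fall back to a mock compressor):

```python
import copy

from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

# Invented conversation for illustration.
messages = [
    {"role": "user", "content": "please keep this user question intact " * 10},
    {"role": "assistant", "content": "this long assistant reply can be truncated " * 10},
]

# exclude_filter=True (the default): messages matching the filter are exempt,
# so only the assistant message is truncated here.
limiter = MessageTokenLimiter(max_tokens_per_message=5, filter_dict={"role": "user"})
truncated = limiter.apply_transform(copy.deepcopy(messages))

# exclude_filter=False inverts the selection: only the matching (user) messages are truncated.
user_only_limiter = MessageTokenLimiter(
    max_tokens_per_message=5, filter_dict={"role": "user"}, exclude_filter=False
)
user_truncated = user_only_limiter.apply_transform(copy.deepcopy(messages))
```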