[Refactor] Transforms Utils #2863

Merged
merged 6 commits into from Jun 6, 2024
Changes from 4 commits
114 changes: 33 additions & 81 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -4,11 +4,12 @@
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import tiktoken
import transforms_util
from termcolor import colored

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from autogen.oai.openai_utils import filter_config
from autogen.types import MessageContentType

from .text_compressors import LLMLingua, TextCompressor

@@ -169,7 +170,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._min_tokens is not None

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
return messages

temp_messages = copy.deepcopy(messages)
@@ -178,13 +179,13 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

for msg in reversed(temp_messages):
# Some messages may not have content.
if not _is_content_right_type(msg.get("content")):
if not transforms_util.is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue

if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
processed_messages_tokens += _count_tokens(msg["content"])
processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
continue

expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
@@ -199,7 +200,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
break

msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
msg_tokens = _count_tokens(msg["content"])
msg_tokens = transforms_util.count_text_tokens(msg["content"])

# prepend the message to the list to preserve order
processed_messages_tokens += msg_tokens
@@ -209,10 +210,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
)
post_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
)

if post_transform_messages_tokens < pre_transform_messages_tokens:
@@ -349,31 +350,32 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
return messages

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
return messages

total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not _is_content_right_type(message.get("content")):
if not transforms_util.is_content_right_type(message.get("content")):
continue

if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
continue

if _is_content_text_empty(message["content"]):
if transforms_util.is_content_text_empty(message["content"]):
continue

cached_content = self._cache_get(message["content"])
cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
cached_content = transforms_util.cache_content_get(self._cache, cache_key)
if cached_content is not None:
savings, compressed_content = cached_content
message["content"], savings = cached_content
else:
savings, compressed_content = self._compress(message["content"])
message["content"], savings = self._compress(message["content"])

self._cache_set(message["content"], compressed_content, savings)
transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)

message["content"] = compressed_content
assert isinstance(savings, int)
total_savings += savings

self._recent_tokens_savings = total_savings
@@ -385,88 +387,38 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
else:
return "No tokens saved with text compression.", False

def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content
return content, 0

def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
for item in content:
if isinstance(item, dict) and "text" in item:
item["text"], savings = self._compress_text(item["text"])
tokens_saved += savings

elif isinstance(item, str):
item, savings = self._compress_text(item)
tokens_saved += savings
return tokens_saved, content

def _compress_text(self, text: str) -> Tuple[int, str]:
return content, tokens_saved

def _compress_text(self, text: str) -> Tuple[str, int]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

savings = 0
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

return savings, compressed_text["compressed_prompt"]

def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
if self._cache:
cached_value = self._cache.get(self._cache_key(content))
if cached_value:
return cached_value

def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"
return compressed_text["compressed_prompt"], savings

def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")


def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def _is_content_right_type(content: Any) -> bool:
return isinstance(content, (str, list))


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False


def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
if not filter_dict:
return True

return len(filter_config([message], filter_dict, exclude)) > 0
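
For context, a minimal usage sketch of the token-limiting transform touched above (assuming the MessageTokenLimiter class exported by this module and its max_tokens_per_message/min_tokens keyword arguments; the message list is illustrative only):

from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

# Truncate each message to at most 16 tokens, but only once the conversation
# holds at least 8 tokens in total (min_tokens_reached handles that check).
limiter = MessageTokenLimiter(max_tokens_per_message=16, min_tokens=8)

messages = [
    {"role": "user", "content": "A deliberately long question that keeps going " * 10},
    {"role": "assistant", "content": [{"type": "text", "text": "A long text part " * 10}]},
]

# apply_transform deep-copies internally, so the original messages stay intact
# and can be compared against the trimmed copy in get_logs.
trimmed = limiter.apply_transform(messages)
log, changed = limiter.get_logs(messages, trimmed)
print(log)  # e.g. "Truncated N tokens. Number of tokens reduced from X to Y"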
114 changes: 114 additions & 0 deletions autogen/agentchat/contrib/capabilities/transforms_util.py
@@ -0,0 +1,114 @@
from typing import Any, Dict, Hashable, List, Optional, Tuple

from autogen import token_count_utils
from autogen.cache.abstract_cache_base import AbstractCache
from autogen.oai.openai_utils import filter_config
from autogen.types import MessageContentType


def cache_key(content: MessageContentType, *args: Hashable) -> str:
"""Calculates the cache key for the given message content and any other hashable args.

Args:
content (MessageContentType): The message content to calculate the cache key for.
*args: Any additional hashable args to include in the cache key.
"""
str_keys = [str(key) for key in (content, *args)]
return "".join(str_keys)


def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
"""Retrieves cachedd content from the cache.

Args:
cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
key (str): The key to retrieve the content from.
"""
if cache:
cached_value = cache.get(key)
if cached_value:
return cached_value


def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
"""Sets content into the cache.

Args:
cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
key (str): The key to set the content into.
content (MessageContentType): The message content to set into the cache.
*extra_values: Additional values to be passed to the cache.
"""
if cache:
cache_value = (content, *extra_values)
cache.set(key, cache_value)


def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value.

Args:
messages (List[Dict]): A list of messages to check.
"""
if not min_tokens:
return True

messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def count_text_tokens(content: MessageContentType) -> int:
"""Calculates the number of text tokens in the given message content.

Args:
content (MessageContentType): The message content to calculate the number of text tokens for.
"""
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
if isinstance(item, str):
token_count += token_count_utils.count_token(item)
else:
token_count += count_text_tokens(item.get("text", ""))
return token_count


def is_content_right_type(content: Any) -> bool:
"""A helper function to check if the passed in content is of the right type."""
return isinstance(content, (str, list))


def is_content_text_empty(content: MessageContentType) -> bool:
"""Checks if the content of the message does not contain any text.

Args:
content (MessageContentType): The message content to check.
"""
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
texts = []
for item in content:
if isinstance(item, str):
texts.append(item)
elif isinstance(item, dict):
texts.append(item.get("text", ""))
return not any(texts)
else:
return True


def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
"""Validates whether the transform should be applied according to the filter dictionary.

Args:
message (Dict[str, Any]): The message to validate.
filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
exclude (bool): Whether to exclude messages that match the filter dictionary.
"""
if not filter_dict:
return True

return len(filter_config([message], filter_dict, exclude)) > 0
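
A quick end-to-end sketch of how these helpers compose, assuming autogen.cache.Cache.disk() for the optional cache and an upper-casing stand-in for a real compression step; everything else comes from the functions above:

from autogen.agentchat.contrib.capabilities import transforms_util
from autogen.cache import Cache

messages = [{"role": "user", "content": "hello world"}]

if transforms_util.min_tokens_reached(messages, min_tokens=1):
    content = messages[0]["content"]
    assert transforms_util.is_content_right_type(content)
    assert not transforms_util.is_content_text_empty(content)

    with Cache.disk() as cache:
        key = transforms_util.cache_key(content, 1)
        cached = transforms_util.cache_content_get(cache, key)
        if cached is None:
            # Stand-in for a real transform such as text compression.
            new_content, savings = content.upper(), 0
            transforms_util.cache_content_set(cache, key, new_content, savings)
        else:
            new_content, savings = cached

    print(new_content, transforms_util.count_text_tokens(new_content))
    # Only transform messages whose fields match the filter dictionary.
    print(transforms_util.should_transform_message(messages[0], {"role": ["user"]}, exclude=False))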
2 changes: 2 additions & 0 deletions autogen/types.py
@@ -1,5 +1,7 @@
from typing import Dict, List, Literal, TypedDict, Union

MessageContentType = Union[str, List[Union[Dict, str]], None]


class UserMessageTextContentPart(TypedDict):
type: Literal["text"]
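
For reference, the new alias covers plain strings, mixed lists of strings and dicts, and None; a few illustrative values:

from autogen.types import MessageContentType

plain: MessageContentType = "just a string"
multimodal: MessageContentType = ["a text chunk", {"type": "text", "text": "another chunk"}]
missing: MessageContentType = None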