Skip to content

Commit 8564bd4

Browse files
authored
[Refactor] Transforms Utils (#2863)
* wip * tests + docstrings * improves tests * fix import
1 parent 102d36d commit 8564bd4

File tree

4 files changed

+221
-82
lines changed

4 files changed

+221
-82
lines changed

autogen/agentchat/contrib/capabilities/transforms.py

+33-82
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import json
32
import sys
43
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
54

@@ -8,8 +7,9 @@
87

98
from autogen import token_count_utils
109
from autogen.cache import AbstractCache, Cache
11-
from autogen.oai.openai_utils import filter_config
10+
from autogen.types import MessageContentType
1211

12+
from . import transforms_util
1313
from .text_compressors import LLMLingua, TextCompressor
1414

1515

@@ -169,7 +169,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
169169
assert self._min_tokens is not None
170170

171171
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
172-
if not _min_tokens_reached(messages, self._min_tokens):
172+
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
173173
return messages
174174

175175
temp_messages = copy.deepcopy(messages)
@@ -178,13 +178,13 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
178178

179179
for msg in reversed(temp_messages):
180180
# Some messages may not have content.
181-
if not _is_content_right_type(msg.get("content")):
181+
if not transforms_util.is_content_right_type(msg.get("content")):
182182
processed_messages.insert(0, msg)
183183
continue
184184

185-
if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
185+
if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
186186
processed_messages.insert(0, msg)
187-
processed_messages_tokens += _count_tokens(msg["content"])
187+
processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
188188
continue
189189

190190
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
@@ -199,7 +199,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
199199
break
200200

201201
msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
202-
msg_tokens = _count_tokens(msg["content"])
202+
msg_tokens = transforms_util.count_text_tokens(msg["content"])
203203

204204
# prepend the message to the list to preserve order
205205
processed_messages_tokens += msg_tokens
@@ -209,10 +209,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
209209

210210
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
211211
pre_transform_messages_tokens = sum(
212-
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
212+
transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
213213
)
214214
post_transform_messages_tokens = sum(
215-
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
215+
transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
216216
)
217217

218218
if post_transform_messages_tokens < pre_transform_messages_tokens:
@@ -349,31 +349,32 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
349349
return messages
350350

351351
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
352-
if not _min_tokens_reached(messages, self._min_tokens):
352+
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
353353
return messages
354354

355355
total_savings = 0
356356
processed_messages = messages.copy()
357357
for message in processed_messages:
358358
# Some messages may not have content.
359-
if not _is_content_right_type(message.get("content")):
359+
if not transforms_util.is_content_right_type(message.get("content")):
360360
continue
361361

362-
if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
362+
if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
363363
continue
364364

365-
if _is_content_text_empty(message["content"]):
365+
if transforms_util.is_content_text_empty(message["content"]):
366366
continue
367367

368-
cached_content = self._cache_get(message["content"])
368+
cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
369+
cached_content = transforms_util.cache_content_get(self._cache, cache_key)
369370
if cached_content is not None:
370-
savings, compressed_content = cached_content
371+
message["content"], savings = cached_content
371372
else:
372-
savings, compressed_content = self._compress(message["content"])
373+
message["content"], savings = self._compress(message["content"])
373374

374-
self._cache_set(message["content"], compressed_content, savings)
375+
transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
375376

376-
message["content"] = compressed_content
377+
assert isinstance(savings, int)
377378
total_savings += savings
378379

379380
self._recent_tokens_savings = total_savings
@@ -385,88 +386,38 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
385386
else:
386387
return "No tokens saved with text compression.", False
387388

388-
def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
389+
def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
389390
"""Compresses the given text or multimodal content using the specified compression method."""
390391
if isinstance(content, str):
391392
return self._compress_text(content)
392393
elif isinstance(content, list):
393394
return self._compress_multimodal(content)
394395
else:
395-
return 0, content
396+
return content, 0
396397

397-
def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
398+
def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
398399
tokens_saved = 0
399-
for msg in content:
400-
if "text" in msg:
401-
savings, msg["text"] = self._compress_text(msg["text"])
400+
for item in content:
401+
if isinstance(item, dict) and "text" in item:
402+
item["text"], savings = self._compress_text(item["text"])
403+
tokens_saved += savings
404+
405+
elif isinstance(item, str):
406+
item, savings = self._compress_text(item)
402407
tokens_saved += savings
403-
return tokens_saved, content
404408

405-
def _compress_text(self, text: str) -> Tuple[int, str]:
409+
return content, tokens_saved
410+
411+
def _compress_text(self, text: str) -> Tuple[str, int]:
406412
"""Compresses the given text using the specified compression method."""
407413
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
408414

409415
savings = 0
410416
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
411417
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
412418

413-
return savings, compressed_text["compressed_prompt"]
414-
415-
def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
416-
if self._cache:
417-
cached_value = self._cache.get(self._cache_key(content))
418-
if cached_value:
419-
return cached_value
420-
421-
def _cache_set(
422-
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
423-
):
424-
if self._cache:
425-
value = (tokens_saved, compressed_content)
426-
self._cache.set(self._cache_key(content), value)
427-
428-
def _cache_key(self, content: Union[str, List[Dict]]) -> str:
429-
return f"{json.dumps(content)}_{self._min_tokens}"
419+
return compressed_text["compressed_prompt"], savings
430420

431421
def _validate_min_tokens(self, min_tokens: Optional[int]):
432422
if min_tokens is not None and min_tokens <= 0:
433423
raise ValueError("min_tokens must be greater than 0 or None")
434-
435-
436-
def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
437-
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
438-
if not min_tokens:
439-
return True
440-
441-
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
442-
return messages_tokens >= min_tokens
443-
444-
445-
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
446-
token_count = 0
447-
if isinstance(content, str):
448-
token_count = token_count_utils.count_token(content)
449-
elif isinstance(content, list):
450-
for item in content:
451-
token_count += _count_tokens(item.get("text", ""))
452-
return token_count
453-
454-
455-
def _is_content_right_type(content: Any) -> bool:
456-
return isinstance(content, (str, list))
457-
458-
459-
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
460-
if isinstance(content, str):
461-
return content == ""
462-
elif isinstance(content, list):
463-
return all(_is_content_text_empty(item.get("text", "")) for item in content)
464-
else:
465-
return False
466-
467-
468-
def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
469-
if not filter_dict:
470-
return True
471-
472-
return len(filter_config([message], filter_dict, exclude)) > 0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from typing import Any, Dict, Hashable, List, Optional, Tuple
2+
3+
from autogen import token_count_utils
4+
from autogen.cache.abstract_cache_base import AbstractCache
5+
from autogen.oai.openai_utils import filter_config
6+
from autogen.types import MessageContentType
7+
8+
9+
def cache_key(content: MessageContentType, *args: Hashable) -> str:
    """Calculates the cache key for the given message content and any other hashable args.

    The key is the string form of `content` followed by the string form of every extra
    argument, joined with an underscore. The separator keeps neighbouring parts from
    running into each other, so ("abc1", 2) and ("abc", 12) produce different keys.

    Args:
        content (MessageContentType): The message content to calculate the cache key for.
        *args: Any additional hashable args to include in the cache key.

    Returns:
        str: The cache key.
    """
    return "_".join(str(part) for part in (content, *args))
18+
19+
20+
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
    """Retrieves cached content from the cache.

    Args:
        cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
        key (str): The key to look up in the cache.

    Returns:
        The value returned by `cache.get(key)`, or None when no cache was given or the
        lookup produced a falsy value.
    """
    # Compare against None instead of relying on truthiness: a cache object that
    # defines __len__ (or __bool__) is falsy while empty and would otherwise be skipped.
    if cache is None:
        return None

    cached_value = cache.get(key)
    return cached_value if cached_value else None
31+
32+
33+
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
    """Sets content into the cache.

    The stored value is a tuple whose first element is `content`, followed by each of
    `extra_values` in order.

    Args:
        cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
        key (str): The key to store the content under.
        content (MessageContentType): The message content to set into the cache.
        *extra_values: Additional values to be stored alongside the content.
    """
    # Compare against None instead of relying on truthiness: a cache object that
    # defines __len__ is falsy while empty, so a truthiness check would never let
    # the first entry be written.
    if cache is None:
        return

    cache.set(key, (content, *extra_values))
45+
46+
47+
def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
    """Tells whether the messages carry at least `min_tokens` text tokens in total.

    Args:
        messages (List[Dict]): A list of messages to check. Messages without a "content" key are not counted.
        min_tokens (None or int): The threshold to reach. A falsy value (None or 0) means there is no
            threshold, and the function returns True without counting anything.
    """
    if min_tokens:
        total_tokens = 0
        for message in messages:
            if "content" in message:
                total_tokens += count_text_tokens(message["content"])
        return total_tokens >= min_tokens

    return True
58+
59+
60+
def count_text_tokens(content: MessageContentType) -> int:
    """Calculates the number of text tokens in the given message content.

    A string is counted directly with `token_count_utils.count_token`. For a list, every
    string item is counted the same way and every dict item contributes the count of its
    "text" entry (missing entries count as the empty string). Any other content, and any
    list item that is neither a string nor a dict, contributes nothing.

    Args:
        content (MessageContentType): The message content to calculate the number of text tokens for.

    Returns:
        int: The total number of text tokens found.
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)

    token_count = 0
    if isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                token_count += token_count_utils.count_token(item)
            elif isinstance(item, dict):
                token_count += count_text_tokens(item.get("text", ""))
            # Items of any other type (e.g. None) carry no text; skipping them avoids
            # an AttributeError from calling .get on a non-dict.
    return token_count
76+
77+
78+
def is_content_right_type(content: Any) -> bool:
    """Reports whether `content` is a string or a list; anything else (including None) is rejected."""
    accepted_types = (str, list)
    return isinstance(content, accepted_types)
81+
82+
83+
def is_content_text_empty(content: MessageContentType) -> bool:
    """Checks if the content of the message does not contain any text.

    A string is empty only when it equals "". A list is empty when none of its string
    items and none of its dict items' "text" entries is truthy; list items of other
    types are ignored. Content that is neither a string nor a list is reported as empty.

    Args:
        content (MessageContentType): The message content to check.
    """
    if isinstance(content, str):
        return content == ""

    if not isinstance(content, list):
        return True

    for part in content:
        if isinstance(part, str):
            text = part
        elif isinstance(part, dict):
            text = part.get("text", "")
        else:
            continue

        # The first piece of non-empty text settles the answer.
        if text:
            return False

    return True
101+
102+
103+
def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
    """Decides whether the transform should be applied to a message, based on the filter dictionary.

    The message is passed on its own to `filter_config`; the transform applies when that
    call returns a non-empty result.

    Args:
        message (Dict[str, Any]): The message to validate.
        filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None or empty,
            the transform is always applied.
        exclude (bool): Whether to exclude messages that match the filter dictionary.
    """
    if filter_dict:
        remaining = filter_config([message], filter_dict, exclude)
        return len(remaining) > 0

    return True

autogen/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, List, Literal, TypedDict, Union
22

3+
MessageContentType = Union[str, List[Union[Dict, str]], None]
4+
35

46
class UserMessageTextContentPart(TypedDict):
57
type: Literal["text"]

0 commit comments

Comments
 (0)