Skip to content

Commit 9088390

Browse files
Standardize printing of MessageTransforms (#2308)
* Standardize printing of MessageTransforms * Fix pre-commit black failure * Add test for transform_messages printing * Return str instead of printing * Rename to_print_stats to verbose * Cleanup * t i# This is a combination of 3 commits. Update requirements * Remove lazy-fixture * Avoid calling apply_transform in two code paths * Format * Replace stats with logs * Handle no content messages in TokenLimiter get_logs() * Move tests from test_transform_messages to test_transforms --------- Co-authored-by: Wael Karkoub <[email protected]>
1 parent d473dee commit 9088390

File tree

4 files changed

+185
-141
lines changed

4 files changed

+185
-141
lines changed

autogen/agentchat/contrib/capabilities/transform_messages.py

+18-22
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import copy
22
from typing import Dict, List
33

4-
from termcolor import colored
5-
64
from autogen import ConversableAgent
75

6+
from ....formatting_utils import colored
87
from .transforms import MessageTransform
98

109

@@ -43,12 +42,14 @@ class TransformMessages:
4342
```
4443
"""
4544

46-
def __init__(self, *, transforms: List[MessageTransform] = []):
45+
def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
4746
"""
4847
Args:
4948
transforms: A list of message transformations to apply.
49+
verbose: Whether to print logs of each transformation or not.
5050
"""
5151
self._transforms = transforms
52+
self._verbose = verbose
5253

5354
def add_to_agent(self, agent: ConversableAgent):
5455
"""Adds the message transformations capability to the specified ConversableAgent.
@@ -61,31 +62,26 @@ def add_to_agent(self, agent: ConversableAgent):
6162
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
6263

6364
def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
64-
temp_messages = copy.deepcopy(messages)
65+
post_transform_messages = copy.deepcopy(messages)
6566
system_message = None
6667

6768
if messages[0]["role"] == "system":
6869
system_message = copy.deepcopy(messages[0])
69-
temp_messages.pop(0)
70+
post_transform_messages.pop(0)
7071

7172
for transform in self._transforms:
72-
temp_messages = transform.apply_transform(temp_messages)
73-
74-
if system_message:
75-
temp_messages.insert(0, system_message)
76-
77-
self._print_stats(messages, temp_messages)
73+
# deepcopy in case pre_transform_messages will later be used for logs printing
74+
pre_transform_messages = (
75+
copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
76+
)
77+
post_transform_messages = transform.apply_transform(pre_transform_messages)
7878

79-
return temp_messages
79+
if self._verbose:
80+
logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
81+
if had_effect:
82+
print(colored(logs_str, "yellow"))
8083

81-
def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]):
82-
pre_transform_messages_len = len(pre_transform_messages)
83-
post_transform_messages_len = len(post_transform_messages)
84+
if system_message:
85+
post_transform_messages.insert(0, system_message)
8486

85-
if pre_transform_messages_len < post_transform_messages_len:
86-
print(
87-
colored(
88-
f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
89-
"yellow",
90-
)
91-
)
87+
return post_transform_messages

autogen/agentchat/contrib/capabilities/transforms.py

+45-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import copy
12
import sys
2-
from typing import Any, Dict, List, Optional, Protocol, Union
3+
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
34

45
import tiktoken
56
from termcolor import colored
@@ -25,6 +26,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
2526
"""
2627
...
2728

29+
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
30+
"""Creates the string including the logs of the transformation
31+
32+
Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
33+
34+
Args:
35+
pre_transform_messages: A list of dictionaries representing messages before the transformation.
36+
post_transform_messages: A list of dictionaries representing messages after the transformation.
37+
38+
Returns:
39+
A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
40+
"""
41+
...
42+
2843

2944
class MessageHistoryLimiter:
3045
"""Limits the number of messages considered by an agent for response generation.
@@ -60,6 +75,18 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
6075

6176
return messages[-self._max_messages :]
6277

78+
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
79+
pre_transform_messages_len = len(pre_transform_messages)
80+
post_transform_messages_len = len(post_transform_messages)
81+
82+
if post_transform_messages_len < pre_transform_messages_len:
83+
logs_str = (
84+
f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
85+
f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
86+
)
87+
return logs_str, True
88+
return "No messages were removed.", False
89+
6390
def _validate_max_messages(self, max_messages: Optional[int]):
6491
if max_messages is not None and max_messages < 1:
6592
raise ValueError("max_messages must be None or greater than 1")
@@ -121,15 +148,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
121148
assert self._max_tokens_per_message is not None
122149
assert self._max_tokens is not None
123150

124-
temp_messages = messages.copy()
151+
temp_messages = copy.deepcopy(messages)
125152
processed_messages = []
126153
processed_messages_tokens = 0
127154

128-
# calculate tokens for all messages
129-
total_tokens = sum(
130-
_count_tokens(msg["content"]) for msg in temp_messages if isinstance(msg.get("content"), (str, list))
131-
)
132-
133155
for msg in reversed(temp_messages):
134156
# Some messages may not have content.
135157
if not isinstance(msg.get("content"), (str, list)):
@@ -154,16 +176,24 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
154176
processed_messages_tokens += msg_tokens
155177
processed_messages.insert(0, msg)
156178

157-
if total_tokens > processed_messages_tokens:
158-
print(
159-
colored(
160-
f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
161-
"yellow",
162-
)
163-
)
164-
165179
return processed_messages
166180

181+
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
182+
pre_transform_messages_tokens = sum(
183+
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
184+
)
185+
post_transform_messages_tokens = sum(
186+
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
187+
)
188+
189+
if post_transform_messages_tokens < pre_transform_messages_tokens:
190+
logs_str = (
191+
f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
192+
f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
193+
)
194+
return logs_str, True
195+
return "No tokens were truncated.", False
196+
167197
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
168198
if isinstance(contents, str):
169199
return self._truncate_tokens(contents, n_tokens)
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import copy
21
import os
32
import sys
43
import tempfile
@@ -7,7 +6,6 @@
76
import pytest
87

98
import autogen
10-
from autogen import token_count_utils
119
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
1210
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter
1311

@@ -18,106 +16,6 @@
1816
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
1917

2018

21-
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
22-
token_count = 0
23-
if isinstance(content, str):
24-
token_count = token_count_utils.count_token(content)
25-
elif isinstance(content, list):
26-
for item in content:
27-
token_count += _count_tokens(item.get("text", ""))
28-
return token_count
29-
30-
31-
def test_limit_token_transform():
32-
"""
33-
Test the TokenLimitTransform capability.
34-
"""
35-
36-
messages = [
37-
{"role": "user", "content": "short string"},
38-
{
39-
"role": "assistant",
40-
"content": [{"type": "text", "text": "very very very very very very very very long string"}],
41-
},
42-
]
43-
44-
# check if token limit per message is not exceeded.
45-
max_tokens_per_message = 5
46-
token_limit_transform = MessageTokenLimiter(max_tokens_per_message=max_tokens_per_message)
47-
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
48-
49-
for message in transformed_messages:
50-
assert _count_tokens(message["content"]) <= max_tokens_per_message
51-
52-
# check if total token limit is not exceeded.
53-
max_tokens = 10
54-
token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens)
55-
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
56-
57-
token_count = 0
58-
for message in transformed_messages:
59-
token_count += _count_tokens(message["content"])
60-
61-
assert token_count <= max_tokens
62-
assert len(transformed_messages) <= len(messages)
63-
64-
# check if token limit per message works nicely with total token limit.
65-
token_limit_transform = MessageTokenLimiter(max_tokens=max_tokens, max_tokens_per_message=max_tokens_per_message)
66-
67-
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
68-
69-
token_count = 0
70-
for message in transformed_messages:
71-
token_count_local = _count_tokens(message["content"])
72-
token_count += token_count_local
73-
assert token_count_local <= max_tokens_per_message
74-
75-
assert token_count <= max_tokens
76-
assert len(transformed_messages) <= len(messages)
77-
78-
79-
def test_limit_token_transform_without_content():
80-
"""Test the TokenLimitTransform with messages that don't have content."""
81-
82-
messages = [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
83-
84-
# check if token limit per message works nicely with total token limit.
85-
token_limit_transform = MessageTokenLimiter(max_tokens=10, max_tokens_per_message=5)
86-
87-
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
88-
89-
assert len(transformed_messages) == len(messages)
90-
91-
92-
def test_limit_token_transform_total_token_count():
93-
"""Tests if the TokenLimitTransform truncates without dropping messages."""
94-
messages = [{"role": "very very very very very"}]
95-
96-
token_limit_transform = MessageTokenLimiter(max_tokens=1)
97-
transformed_messages = token_limit_transform.apply_transform(copy.deepcopy(messages))
98-
99-
assert len(transformed_messages) == 1
100-
101-
102-
def test_max_message_history_length_transform():
103-
"""
104-
Test the MessageHistoryLimiter capability to limit the number of messages.
105-
"""
106-
messages = [
107-
{"role": "user", "content": "hello"},
108-
{"role": "assistant", "content": [{"type": "text", "text": "there"}]},
109-
{"role": "user", "content": "how"},
110-
{"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
111-
]
112-
113-
max_messages = 2
114-
messages_limiter = MessageHistoryLimiter(max_messages=max_messages)
115-
transformed_messages = messages_limiter.apply_transform(copy.deepcopy(messages))
116-
117-
assert len(transformed_messages) == max_messages
118-
assert transformed_messages == messages[max_messages:]
119-
120-
12119
@pytest.mark.skipif(skip_openai, reason="Requested to skip openai test.")
12220
def test_transform_messages_capability():
12321
"""Test the TransformMessages capability to handle long contexts.
@@ -172,6 +70,4 @@ def test_transform_messages_capability():
17270

17371

17472
if __name__ == "__main__":
175-
test_limit_token_transform()
176-
test_max_message_history_length_transform()
17773
test_transform_messages_capability()

0 commit comments

Comments
 (0)