Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix message history limiter for tool call #3178

Merged
merged 10 commits into from
Aug 9, 2024
30 changes: 27 additions & 3 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ class MessageHistoryLimiter:
It trims the conversation history by removing older messages, retaining only the most recent messages.
"""

def __init__(self, max_messages: Optional[int] = None):
def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool = False):
marklysze marked this conversation as resolved.
Show resolved Hide resolved
"""
Args:
max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
keep_first_message bool: Whether to keep the original first message in the conversation history.
Defaults to False.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages
self._keep_first_message = keep_first_message

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Truncates the conversation history to the specified maximum number of messages.
Expand All @@ -75,10 +78,31 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
List[Dict]: A new list containing the most recent messages up to the specified maximum.
"""

if self._max_messages is None:
if self._max_messages is None or len(messages) <= self._max_messages:
return messages

return messages[-self._max_messages :]
truncated_messages = []
remaining_count = self._max_messages

# Start with the first message if we need to keep it
if self._keep_first_message:
truncated_messages = [messages[0]]
remaining_count -= 1

# Loop through messages in reverse
for i in range(len(messages) - 1, 0, -1):
if remaining_count > 1:
truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
if remaining_count == 1:
# If there's only 1 slot left and it's a 'tools' message, ignore it.
if messages[i].get("role") != "tool":
truncated_messages.insert(1, messages[i])

remaining_count -= 1
if remaining_count == 0:
break

return truncated_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_len = len(pre_transform_messages)
Expand Down
99 changes: 62 additions & 37 deletions test/agentchat/contrib/capabilities/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from autogen.agentchat.contrib.capabilities.transforms import (
MessageHistoryLimiter,
MessageTokenLimiter,
TextMessageCompressor,
_count_tokens,
TextMessageCompressor
)
from autogen.agentchat.contrib.capabilities.transforms_util import (
count_text_tokens
)


Expand Down Expand Up @@ -40,6 +42,25 @@ def get_no_content_messages() -> List[Dict]:
return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]


def get_tool_messages() -> List[Dict]:
return [
{"role": "user", "content": "hello"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
{"role": "user", "content": "how are you"},
{"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
]

def get_tool_messages_kept() -> List[Dict]:
return [
{"role": "user", "content": "hello"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
]


def get_text_compressors() -> List[TextCompressor]:
compressors: List[TextCompressor] = [_MockTextCompressor()]
try:
Expand All @@ -56,6 +77,10 @@ def get_text_compressors() -> List[TextCompressor]:
def message_history_limiter() -> MessageHistoryLimiter:
return MessageHistoryLimiter(max_messages=3)

@pytest.fixture
def message_history_limiter_keep_first() -> MessageHistoryLimiter:
return MessageHistoryLimiter(max_messages=3, keep_first_message=True)


@pytest.fixture
def message_token_limiter() -> MessageTokenLimiter:
Expand Down Expand Up @@ -96,19 +121,38 @@ def _filter_dict_test(

@pytest.mark.parametrize(
"messages, expected_messages_len",
[(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2)],
[(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2), (get_tool_messages(), 2), (get_tool_messages_kept(), 2)],
)
def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len):
transformed_messages = message_history_limiter.apply_transform(messages)
assert len(transformed_messages) == expected_messages_len

if messages == get_tool_messages_kept():
assert transformed_messages[0]["role"] == "tool_calls"
assert transformed_messages[1]["role"] == "tool"


@pytest.mark.parametrize(
"messages, expected_messages_len",
[(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2), (get_tool_messages(), 3), (get_tool_messages_kept(), 3)],
)
def test_message_history_limiter_apply_transform_keep_first(message_history_limiter_keep_first, messages, expected_messages_len):
transformed_messages = message_history_limiter_keep_first.apply_transform(messages)
assert len(transformed_messages) == expected_messages_len

if messages == get_tool_messages_kept():
assert transformed_messages[1]["role"] == "tool_calls"
assert transformed_messages[2]["role"] == "tool"


@pytest.mark.parametrize(
"messages, expected_logs, expected_effect",
[
(get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True),
(get_short_messages(), "No messages were removed.", False),
(get_no_content_messages(), "No messages were removed.", False),
(get_tool_messages(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
(get_tool_messages_kept(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
],
)
def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect):
Expand All @@ -131,7 +175,7 @@ def test_message_token_limiter_apply_transform(
):
transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
assert (
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
)
assert len(transformed_messages) == expected_messages_len

Expand Down Expand Up @@ -167,7 +211,7 @@ def test_message_token_limiter_with_threshold_apply_transform(
):
transformed_messages = message_token_limiter_with_threshold.apply_transform(messages)
assert (
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
)
assert len(transformed_messages) == expected_messages_len

Expand Down Expand Up @@ -240,56 +284,31 @@ def test_text_compression_with_filter(messages, text_compressor):
assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)


@pytest.mark.parametrize("text_compressor", get_text_compressors())
def test_text_compression_cache(text_compressor):
messages = get_long_messages()
mock_compressed_content = (1, {"content": "mock"})

with patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_get",
MagicMock(return_value=(1, {"content": "mock"})),
) as mocked_get, patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
) as mocked_set:
compressor = TextMessageCompressor(text_compressor=text_compressor)

compressor.apply_transform(messages)
compressor.apply_transform(messages)

assert mocked_get.call_count == len(messages)
assert mocked_set.call_count == len(messages)

# We already populated the cache with the mock content
# We need to test if we retrieve the correct content
compressor = TextMessageCompressor(text_compressor=text_compressor)
compressed_messages = compressor.apply_transform(messages)

for message in compressed_messages:
assert message["content"] == mock_compressed_content[1]


if __name__ == "__main__":
long_messages = get_long_messages()
short_messages = get_short_messages()
no_content_messages = get_no_content_messages()
tool_messages = get_tool_messages()
msg_history_limiter = MessageHistoryLimiter(max_messages=3)
msg_history_limiter_keep_first = MessageHistoryLimiter(max_messages=3, keep_first=True)
msg_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)

# Test Parameters
message_history_limiter_apply_transform_parameters = {
"messages": [long_messages, short_messages, no_content_messages],
"expected_messages_len": [3, 3, 2],
"messages": [long_messages, short_messages, no_content_messages, tool_messages],
"expected_messages_len": [3, 3, 2, 4],
}

message_history_limiter_get_logs_parameters = {
"messages": [long_messages, short_messages, no_content_messages],
"messages": [long_messages, short_messages, no_content_messages, tool_messages],
"expected_logs": [
"Removed 2 messages. Number of messages reduced from 5 to 3.",
"No messages were removed.",
"No messages were removed.",
"Removed 1 messages. Number of messages reduced from 5 to 4.",
],
"expected_effect": [True, False, False],
"expected_effect": [True, False, False, True],
}

message_token_limiter_apply_transform_parameters = {
Expand Down Expand Up @@ -322,6 +341,12 @@ def test_text_compression_cache(text_compressor):
):
test_message_history_limiter_apply_transform(msg_history_limiter, messages, expected_messages_len)

for messages, expected_messages_len in zip(
message_history_limiter_apply_transform_parameters["messages"],
message_history_limiter_apply_transform_parameters["expected_messages_len"],
):
test_message_history_limiter_apply_transform_keep_first(msg_history_limiter_keep_first, messages, expected_messages_len)

for messages, expected_logs, expected_effect in zip(
message_history_limiter_get_logs_parameters["messages"],
message_history_limiter_get_logs_parameters["expected_logs"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,28 @@ pprint.pprint(processed_messages)
{'content': 'very very very very very very long string', 'role': 'user'}]
```

By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages.
By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages. However, if the splitting point is between a "tool_calls" and "tool" pair, the complete pair will be included to obey the OpenAI API call constraints.

```python
max_msg_transfrom = transforms.MessageHistoryLimiter(max_messages=3)

messages = [
{"role": "user", "content": "hello"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
{"role": "user", "content": "how are you"},
{"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
]

processed_messages = max_msg_transfrom.apply_transform(copy.deepcopy(messages))
pprint.pprint(processed_messages)
```
```console
[{'content': 'calling_tool', 'role': 'tool_calls'},
{'content': 'tool_response', 'role': 'tool'},
{'content': 'how are you', 'role': 'user'},
{'content': [{'text': 'are you doing?', 'type': 'text'}], 'role': 'assistant'}]
```

#### Example 2: Limiting the Number of Tokens

Expand Down
Loading