Skip to content

Commit

Permalink
Fix message history limiter for tool call (#3178)
Browse files Browse the repository at this point in the history
* fix: message history limiter to support tool calls

* add: pytest and docs for message history limiter for tool calls

* Added keep_first_message for HistoryLimiter transform

* Update to inbetween to between

* Updated keep_first_message to non-optional, logic for history limiter

* Update transforms.py

* Update test_transforms to match utils introduction, add keep_first_message testing

* Update test_transforms.py for pre-commit checks

---------

Co-authored-by: Mark Sze <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent fb788c3 commit 972b4ed
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 40 deletions.
30 changes: 27 additions & 3 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ class MessageHistoryLimiter:
It trims the conversation history by removing older messages, retaining only the most recent messages.
"""

def __init__(self, max_messages: Optional[int] = None):
def __init__(self, max_messages: Optional[int] = None, keep_first_message: bool = False):
"""
Args:
max_messages Optional[int]: Maximum number of messages to keep in the context. Must be greater than 0 if not None.
keep_first_message bool: Whether to keep the original first message in the conversation history.
Defaults to False.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages
self._keep_first_message = keep_first_message

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Truncates the conversation history to the specified maximum number of messages.
Expand All @@ -75,10 +78,31 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
List[Dict]: A new list containing the most recent messages up to the specified maximum.
"""

if self._max_messages is None:
if self._max_messages is None or len(messages) <= self._max_messages:
return messages

return messages[-self._max_messages :]
truncated_messages = []
remaining_count = self._max_messages

# Start with the first message if we need to keep it
if self._keep_first_message:
truncated_messages = [messages[0]]
remaining_count -= 1

# Loop through messages in reverse
for i in range(len(messages) - 1, 0, -1):
if remaining_count > 1:
truncated_messages.insert(1 if self._keep_first_message else 0, messages[i])
if remaining_count == 1:
# If there's only 1 slot left and it's a 'tools' message, ignore it.
if messages[i].get("role") != "tool":
truncated_messages.insert(1, messages[i])

remaining_count -= 1
if remaining_count == 0:
break

return truncated_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_len = len(pre_transform_messages)
Expand Down
115 changes: 79 additions & 36 deletions test/agentchat/contrib/capabilities/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
MessageHistoryLimiter,
MessageTokenLimiter,
TextMessageCompressor,
_count_tokens,
)
from autogen.agentchat.contrib.capabilities.transforms_util import count_text_tokens


class _MockTextCompressor:
Expand Down Expand Up @@ -40,6 +40,26 @@ def get_no_content_messages() -> List[Dict]:
return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]


def get_tool_messages() -> List[Dict]:
return [
{"role": "user", "content": "hello"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
{"role": "user", "content": "how are you"},
{"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
]


def get_tool_messages_kept() -> List[Dict]:
return [
{"role": "user", "content": "hello"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
]


def get_text_compressors() -> List[TextCompressor]:
compressors: List[TextCompressor] = [_MockTextCompressor()]
try:
Expand All @@ -57,6 +77,11 @@ def message_history_limiter() -> MessageHistoryLimiter:
return MessageHistoryLimiter(max_messages=3)


@pytest.fixture
def message_history_limiter_keep_first() -> MessageHistoryLimiter:
return MessageHistoryLimiter(max_messages=3, keep_first_message=True)


@pytest.fixture
def message_token_limiter() -> MessageTokenLimiter:
return MessageTokenLimiter(max_tokens_per_message=3)
Expand Down Expand Up @@ -96,19 +121,52 @@ def _filter_dict_test(

@pytest.mark.parametrize(
"messages, expected_messages_len",
[(get_long_messages(), 3), (get_short_messages(), 3), (get_no_content_messages(), 2)],
[
(get_long_messages(), 3),
(get_short_messages(), 3),
(get_no_content_messages(), 2),
(get_tool_messages(), 2),
(get_tool_messages_kept(), 2),
],
)
def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len):
transformed_messages = message_history_limiter.apply_transform(messages)
assert len(transformed_messages) == expected_messages_len

if messages == get_tool_messages_kept():
assert transformed_messages[0]["role"] == "tool_calls"
assert transformed_messages[1]["role"] == "tool"


@pytest.mark.parametrize(
"messages, expected_messages_len",
[
(get_long_messages(), 3),
(get_short_messages(), 3),
(get_no_content_messages(), 2),
(get_tool_messages(), 3),
(get_tool_messages_kept(), 3),
],
)
def test_message_history_limiter_apply_transform_keep_first(
message_history_limiter_keep_first, messages, expected_messages_len
):
transformed_messages = message_history_limiter_keep_first.apply_transform(messages)
assert len(transformed_messages) == expected_messages_len

if messages == get_tool_messages_kept():
assert transformed_messages[1]["role"] == "tool_calls"
assert transformed_messages[2]["role"] == "tool"


@pytest.mark.parametrize(
"messages, expected_logs, expected_effect",
[
(get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True),
(get_short_messages(), "No messages were removed.", False),
(get_no_content_messages(), "No messages were removed.", False),
(get_tool_messages(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
(get_tool_messages_kept(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
],
)
def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect):
Expand All @@ -131,7 +189,8 @@ def test_message_token_limiter_apply_transform(
):
transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
assert (
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg)
== expected_token_count
)
assert len(transformed_messages) == expected_messages_len

Expand Down Expand Up @@ -167,7 +226,8 @@ def test_message_token_limiter_with_threshold_apply_transform(
):
transformed_messages = message_token_limiter_with_threshold.apply_transform(messages)
assert (
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg)
== expected_token_count
)
assert len(transformed_messages) == expected_messages_len

Expand Down Expand Up @@ -240,56 +300,31 @@ def test_text_compression_with_filter(messages, text_compressor):
assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)


@pytest.mark.parametrize("text_compressor", get_text_compressors())
def test_text_compression_cache(text_compressor):
messages = get_long_messages()
mock_compressed_content = (1, {"content": "mock"})

with patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_get",
MagicMock(return_value=(1, {"content": "mock"})),
) as mocked_get, patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
) as mocked_set:
compressor = TextMessageCompressor(text_compressor=text_compressor)

compressor.apply_transform(messages)
compressor.apply_transform(messages)

assert mocked_get.call_count == len(messages)
assert mocked_set.call_count == len(messages)

# We already populated the cache with the mock content
# We need to test if we retrieve the correct content
compressor = TextMessageCompressor(text_compressor=text_compressor)
compressed_messages = compressor.apply_transform(messages)

for message in compressed_messages:
assert message["content"] == mock_compressed_content[1]


if __name__ == "__main__":
long_messages = get_long_messages()
short_messages = get_short_messages()
no_content_messages = get_no_content_messages()
tool_messages = get_tool_messages()
msg_history_limiter = MessageHistoryLimiter(max_messages=3)
msg_history_limiter_keep_first = MessageHistoryLimiter(max_messages=3, keep_first=True)
msg_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)

# Test Parameters
message_history_limiter_apply_transform_parameters = {
"messages": [long_messages, short_messages, no_content_messages],
"expected_messages_len": [3, 3, 2],
"messages": [long_messages, short_messages, no_content_messages, tool_messages],
"expected_messages_len": [3, 3, 2, 4],
}

message_history_limiter_get_logs_parameters = {
"messages": [long_messages, short_messages, no_content_messages],
"messages": [long_messages, short_messages, no_content_messages, tool_messages],
"expected_logs": [
"Removed 2 messages. Number of messages reduced from 5 to 3.",
"No messages were removed.",
"No messages were removed.",
"Removed 1 messages. Number of messages reduced from 5 to 4.",
],
"expected_effect": [True, False, False],
"expected_effect": [True, False, False, True],
}

message_token_limiter_apply_transform_parameters = {
Expand Down Expand Up @@ -322,6 +357,14 @@ def test_text_compression_cache(text_compressor):
):
test_message_history_limiter_apply_transform(msg_history_limiter, messages, expected_messages_len)

for messages, expected_messages_len in zip(
message_history_limiter_apply_transform_parameters["messages"],
message_history_limiter_apply_transform_parameters["expected_messages_len"],
):
test_message_history_limiter_apply_transform_keep_first(
msg_history_limiter_keep_first, messages, expected_messages_len
)

for messages, expected_logs, expected_effect in zip(
message_history_limiter_get_logs_parameters["messages"],
message_history_limiter_get_logs_parameters["expected_logs"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,28 @@ pprint.pprint(processed_messages)
{'content': 'very very very very very very long string', 'role': 'user'}]
```

By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages.
By applying the `MessageHistoryLimiter`, we can see that we were able to limit the context history to the 3 most recent messages. However, if the splitting point is between a "tool_calls" and "tool" pair, the complete pair will be included to obey the OpenAI API call constraints.

```python
max_msg_transfrom = transforms.MessageHistoryLimiter(max_messages=3)

messages = [
{"role": "user", "content": "hello"},
{"role": "tool_calls", "content": "calling_tool"},
{"role": "tool", "content": "tool_response"},
{"role": "user", "content": "how are you"},
{"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
]

processed_messages = max_msg_transfrom.apply_transform(copy.deepcopy(messages))
pprint.pprint(processed_messages)
```
```console
[{'content': 'calling_tool', 'role': 'tool_calls'},
{'content': 'tool_response', 'role': 'tool'},
{'content': 'how are you', 'role': 'user'},
{'content': [{'text': 'are you doing?', 'type': 'text'}], 'role': 'assistant'}]
```

#### Example 2: Limiting the Number of Tokens

Expand Down

0 comments on commit 972b4ed

Please sign in to comment.