Transform Messages Capability #1923

Merged: 50 commits into main from transform-messages-capability on Mar 28, 2024
Commits (50)
1c698d1
wip
WaelKarkoub Mar 8, 2024
236c05b
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 8, 2024
f94d293
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 8, 2024
fc1ca64
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 9, 2024
632321e
Adds docstrings
WaelKarkoub Mar 9, 2024
dba8c8b
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 13, 2024
97de5e4
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 13, 2024
c528e3a
fixed spellings
WaelKarkoub Mar 14, 2024
df63277
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 14, 2024
2174274
wip
WaelKarkoub Mar 14, 2024
d222af3
fixed errors
WaelKarkoub Mar 14, 2024
f8e8402
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 16, 2024
4d29b65
better class names
WaelKarkoub Mar 16, 2024
48053b4
adds tests
WaelKarkoub Mar 16, 2024
77d38c7
added tests to workflow
WaelKarkoub Mar 16, 2024
c85171a
improved token counting
WaelKarkoub Mar 16, 2024
b3f4d89
improved notebook
WaelKarkoub Mar 16, 2024
2dd7898
improved token counting in test
WaelKarkoub Mar 16, 2024
89a1e52
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 16, 2024
a692d16
improved docstrings
WaelKarkoub Mar 16, 2024
b1d4ebe
fix inconsistencies
WaelKarkoub Mar 16, 2024
4389ac0
changed by mistake
WaelKarkoub Mar 16, 2024
9a81321
fixed docstring
WaelKarkoub Mar 16, 2024
83eed2f
fixed details
WaelKarkoub Mar 16, 2024
0224641
improves tests + adds openai contrib test
WaelKarkoub Mar 16, 2024
bfd572d
fix spelling oai contrib test
WaelKarkoub Mar 17, 2024
9fd9611
clearer docstrings
WaelKarkoub Mar 17, 2024
8084b01
remove repeated docstr
WaelKarkoub Mar 17, 2024
35e5694
improved notebook
WaelKarkoub Mar 17, 2024
34f1892
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 17, 2024
7408562
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 18, 2024
6255327
adds metadata to notebook
WaelKarkoub Mar 18, 2024
b94debc
fix merge
WaelKarkoub Mar 22, 2024
66eb122
Improve outline and description (#2125)
gagb Mar 23, 2024
02adef3
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 24, 2024
d597a1c
better dir structure
WaelKarkoub Mar 24, 2024
36426f1
clip max tokens to allowed tokens
WaelKarkoub Mar 24, 2024
93014db
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 25, 2024
b4f7248
more accurate comments/docstrs
WaelKarkoub Mar 25, 2024
7c66348
add deprecation warning
WaelKarkoub Mar 25, 2024
3d168cb
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 25, 2024
5e7423c
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 26, 2024
2fb4589
fix front matter desc
WaelKarkoub Mar 26, 2024
3156cc2
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 27, 2024
0f23333
add deprecation warning notebook
WaelKarkoub Mar 27, 2024
8b0e77a
undo local notebook settings changes
WaelKarkoub Mar 27, 2024
4fe8aa9
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 28, 2024
a8b3c5a
format notebook
WaelKarkoub Mar 28, 2024
a68780c
format workflow
WaelKarkoub Mar 28, 2024
3c29935
Merge branch 'main' into transform-messages-capability
WaelKarkoub Mar 28, 2024
37 changes: 37 additions & 0 deletions .github/workflows/contrib-tests.yml
@@ -291,3 +291,40 @@ jobs:
        with:
          file: ./coverage.xml
          flags: unittests

  TransformMessages:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-2019]
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies for all tests
        run: |
          python -m pip install --upgrade pip wheel
          pip install pytest
      - name: Install packages and dependencies for Transform Messages
        run: |
          pip install -e .
      - name: Set AUTOGEN_USE_DOCKER based on OS
        shell: bash
        run: |
          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
          fi
      - name: Coverage
        run: |
          pip install "coverage>=5.3"
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_transform_messages.py --skip-openai
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
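To reproduce this job locally, the following sketch mirrors the workflow's install and Coverage steps. It assumes a checkout of the repository root and Python 3.11, as pinned in the matrix:

```bash
# Mirror the CI steps above; run from the repository root.
python -m pip install --upgrade pip wheel
pip install -e . pytest
pip install "coverage>=5.3"

# --skip-openai keeps the run offline, exactly as in CI.
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_transform_messages.py --skip-openai
coverage xml
```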
17 changes: 13 additions & 4 deletions autogen/agentchat/contrib/capabilities/context_handling.py
@@ -1,9 +1,18 @@
 import sys
-from termcolor import colored
-from typing import Dict, Optional, List
-from autogen import ConversableAgent
-from autogen import token_count_utils
+from typing import Dict, List, Optional
+from warnings import warn
+
+import tiktoken
+from termcolor import colored
+
+from autogen import ConversableAgent, token_count_utils
+
+warn(
+    "Context handling with TransformChatHistory is deprecated. "
+    "Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)


 class TransformChatHistory:
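For readers migrating off the deprecated capability, here is a minimal before/after sketch. The `TransformChatHistory` keyword arguments shown are assumptions about the old interface; the replacement classes come from the modules added in this PR:

```python
# Before (deprecated); argument names here are assumptions about the old API:
# from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory
# context_handling = TransformChatHistory(max_messages=10, max_tokens=1000)
# context_handling.add_to_agent(agent)

# After:
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

transform_messages = TransformMessages(
    transforms=[
        MessageHistoryLimiter(max_messages=10),  # bounds the message count
        MessageTokenLimiter(max_tokens=1000),  # bounds the total token count
    ]
)
# transform_messages.add_to_agent(agent)  # same registration pattern as before
```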
90 changes: 90 additions & 0 deletions autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -0,0 +1,90 @@
import copy
from typing import Dict, List

from termcolor import colored

from autogen import ConversableAgent

from .transforms import MessageTransform


class TransformMessages:
    """Agent capability for transforming messages before reply generation.

    This capability allows you to apply a series of message transformations to
    a ConversableAgent's incoming messages before they are processed for response
    generation. This is useful for tasks such as:

    - Limiting the number of messages considered for context.
    - Truncating messages to meet token limits.
    - Filtering sensitive information.
    - Customizing message formatting.

    To use `TransformMessages`:

    1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
    2. Instantiate `TransformMessages` with a list of these transformations.
    3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.

    NOTE: The order of message transformations matters; applying the same
    transformations in a different order can produce different results.

    Example:
        ```python
        from autogen import ConversableAgent
        from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
        from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

        max_messages = MessageHistoryLimiter(max_messages=2)
        truncate_messages = MessageTokenLimiter(max_tokens=500)
        transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])

        agent = ConversableAgent(...)
        transform_messages.add_to_agent(agent)
        ```
    """

    def __init__(self, *, transforms: List[MessageTransform] = []):
        """
        Args:
            transforms: A list of message transformations to apply.
        """
        self._transforms = transforms

    def add_to_agent(self, agent: ConversableAgent):
        """Adds the message transformations capability to the specified ConversableAgent.

        This function performs the following modification to the agent:

        1. Registers a hook that automatically transforms all messages before they are
           processed for response generation.
        """
        agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        temp_messages = copy.deepcopy(messages)
        system_message = None

        if messages[0]["role"] == "system":
            system_message = copy.deepcopy(messages[0])
            temp_messages.pop(0)

        for transform in self._transforms:
            temp_messages = transform.apply_transform(temp_messages)

        if system_message:
            temp_messages.insert(0, system_message)

        self._print_stats(messages, temp_messages)

        return temp_messages

    def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]):
        pre_transform_messages_len = len(pre_transform_messages)
        post_transform_messages_len = len(post_transform_messages)

        # Only report when the transforms actually shrank the message list.
        if pre_transform_messages_len > post_transform_messages_len:
            print(
                colored(
                    f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
                    "yellow",
                )
            )
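A minimal sketch of the system-message handling implemented above. It calls the private `_transform_messages` hook directly, purely for illustration; in practice the hook runs automatically once registered via `add_to_agent`:

```python
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter

capability = TransformMessages(transforms=[MessageHistoryLimiter(max_messages=1)])

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
]

# The hook pops the system message before applying transforms and reinserts it
# afterwards, so it survives even an aggressive history limit.
transformed = capability._transform_messages(messages)
assert transformed[0]["role"] == "system"
assert transformed[-1]["content"] == "second question"
assert len(transformed) == 2
```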
210 changes: 210 additions & 0 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -0,0 +1,210 @@
import copy
import sys
from typing import Any, Dict, List, Optional, Protocol, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils


class MessageTransform(Protocol):
    """Defines a contract for message transformation.

    Classes implementing this protocol should provide an `apply_transform` method
    that takes a list of messages and returns the transformed list.
    """

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies a transformation to a list of messages.

        Args:
            messages: A list of dictionaries representing messages.

        Returns:
            A new list of dictionaries containing the transformed messages.
        """
        ...


class MessageHistoryLimiter:
    """Limits the number of messages considered by an agent for response generation.

    This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
    It trims the conversation history by removing older messages, retaining only the most recent messages.
    """

    def __init__(self, max_messages: Optional[int] = None):
        """
        Args:
            max_messages (None or int): Maximum number of messages to keep in the context.
                Must be greater than 0 if not None.
        """
        self._validate_max_messages(max_messages)
        self._max_messages = max_messages

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Truncates the conversation history to the specified maximum number of messages.

        This method returns a new list containing the most recent messages up to the specified
        maximum number of messages (max_messages). If max_messages is None, it returns the
        original list of messages unmodified.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the most recent messages up to the specified maximum.
        """
        if self._max_messages is None:
            return messages

        return messages[-self._max_messages :]

    def _validate_max_messages(self, max_messages: Optional[int]):
        if max_messages is not None and max_messages < 1:
            raise ValueError("max_messages must be None or greater than 0")


class MessageTokenLimiter:
    """Truncates messages to meet token limits for efficient processing and response generation.

    This transformation applies two levels of truncation to the conversation history:

    1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
    2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.

    NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
    counts for the same text.

    NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
    (e.g., images).

    The truncation process follows these steps in order:

    1. Messages are processed in reverse order (newest to oldest).
    2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
       and other types of content, only the text content is truncated.
    3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
       exceeds this limit, the current message being processed as well as any remaining messages are discarded.
    4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
       original message order.
    """

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
    ):
        """
        Args:
            max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
                Must be greater than or equal to 0 if not None.
            max_tokens (None or int): Maximum number of tokens to keep in the chat history.
                Must be greater than or equal to 0 if not None.
            model (str): The target OpenAI model for tokenization alignment.
        """
        self._model = model
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies token truncation to the conversation history.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the truncated messages up to the specified token limits.
        """
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None

        # Deep copy so truncation does not mutate the caller's message dicts.
        temp_messages = copy.deepcopy(messages)
        processed_messages = []
        processed_messages_tokens = 0

        # Calculate tokens for all messages.
        total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

        for msg in reversed(temp_messages):
            msg["content"] = self._truncate_str_to_tokens(msg["content"])
            msg_tokens = _count_tokens(msg["content"])

            # If adding this message would exceed the token limit, discard it and all remaining messages.
            if processed_messages_tokens + msg_tokens > self._max_tokens:
                break

            # Prepend the message to the list to preserve order.
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)

        if total_tokens > processed_messages_tokens:
            print(
                colored(
                    f"Truncated {total_tokens - processed_messages_tokens} tokens. "
                    f"Tokens reduced from {total_tokens} to {processed_messages_tokens}.",
                    "yellow",
                )
            )

        return processed_messages

    def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
        if isinstance(contents, str):
            return self._truncate_tokens(contents)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure."""
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"])
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str) -> str:
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")

        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None

        if max_tokens is not None and allowed_tokens is not None:
            if max_tokens > allowed_tokens:
                print(
                    colored(
                        f"max_tokens was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. "
                        f"Capping it to {allowed_tokens}.",
                        "yellow",
                    )
                )
                return allowed_tokens

        return max_tokens if max_tokens is not None else sys.maxsize


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
    token_count = 0
    if isinstance(content, str):
        token_count = token_count_utils.count_token(content)
    elif isinstance(content, list):
        for item in content:
            token_count += _count_tokens(item.get("text", ""))
    return token_count
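A minimal sketch exercising the two transforms standalone, as described in the class docstrings above. Exact token counts depend on the default `gpt-3.5-turbo-0613` encoder, so the numbers here are illustrative:

```python
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

messages = [
    {"role": "user", "content": "a long opening message " * 20},
    {"role": "assistant", "content": "short reply"},
    {"role": "user", "content": "follow-up question"},
]

# Keep only the two most recent messages.
limited = MessageHistoryLimiter(max_messages=2).apply_transform(messages)
assert len(limited) == 2

# Then cap each remaining message at 16 tokens of the model's encoding.
truncated = MessageTokenLimiter(max_tokens_per_message=16).apply_transform(limited)

# Note: order matters. Limiting history first and truncating tokens second can
# yield a different context than applying the same transforms in reverse.
```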
4 changes: 4 additions & 0 deletions notebook/agentchat_capability_long_context_handling.ipynb
@@ -6,6 +6,10 @@
"source": [
"# Handling A Long Context via `TransformChatHistory`\n",
"\n",
"<div class=\"alert alert-warning\" role=\"alert\">\n",
" <strong>Deprecation Notice:</strong> <code>TransformChatHistory</code> is no longer supported. Please use <code>TransformMessages</code> as the new standard method. For the latest examples, visit the notebook at <a href=\"https://github.com/microsoft/autogen/blob/main/notebook/agentchat_transform_messages.ipynb\" target=\"_blank\">notebook/agentchat_transform_messages.ipynb</a>.\n",
"</div>\n",
"\n",
"This notebook illustrates how you can use the `TransformChatHistory` capability to give any `Conversable` agent an ability to handle a long context. \n",
"\n",
"````{=mdx}\n",