Skip to content

Commit 47d6c75

Browse files
gagbekzhusonichi
authored
Proxy PR for Long Context Capability 1513 (#1591)
* Add new capability to handle long context * Make print conditional * Remove superfluous comment * Fix msg order * Allow user to specify max_tokens * Add ability to specify max_tokens per message; improve name * Improve doc and readability * Add tests * Improve documentation and add tests per Erik and Chi's feedback * Update notebook * Update doc string of add to agents * Improve doc string * improve notebook * Update github workflows for context handling * Update docstring * update notebook to use raw config list. * Update contrib-openai.yml remove _target * Fix code formatting * Fix workflow file * Update .github/workflows/contrib-openai.yml --------- Co-authored-by: Eric Zhu <[email protected]> Co-authored-by: Chi Wang <[email protected]>
1 parent a3c3317 commit 47d6c75

File tree

7 files changed

+920
-6
lines changed

7 files changed

+920
-6
lines changed

.github/workflows/contrib-openai.yml

+39
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,42 @@ jobs:
260260
with:
261261
file: ./coverage.xml
262262
flags: unittests
263+
ContextHandling:
264+
strategy:
265+
matrix:
266+
os: [ubuntu-latest]
267+
python-version: ["3.11"]
268+
runs-on: ${{ matrix.os }}
269+
environment: openai1
270+
steps:
271+
# checkout to pr branch
272+
- name: Checkout
273+
uses: actions/checkout@v3
274+
with:
275+
ref: ${{ github.event.pull_request.head.sha }}
276+
- name: Set up Python ${{ matrix.python-version }}
277+
uses: actions/setup-python@v4
278+
with:
279+
python-version: ${{ matrix.python-version }}
280+
- name: Install packages and dependencies
281+
run: |
282+
docker --version
283+
python -m pip install --upgrade pip wheel
284+
pip install -e .
285+
python -c "import autogen"
286+
pip install coverage pytest
287+
- name: Coverage
288+
env:
289+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
290+
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
291+
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
292+
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
293+
BING_API_KEY: ${{ secrets.BING_API_KEY }}
294+
run: |
295+
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py
296+
coverage xml
297+
- name: Upload coverage to Codecov
298+
uses: codecov/codecov-action@v3
299+
with:
300+
file: ./coverage.xml
301+
flags: unittests

.github/workflows/contrib-tests.yml

+37
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,40 @@ jobs:
253253
with:
254254
file: ./coverage.xml
255255
flags: unittests
256+
257+
ContextHandling:
258+
runs-on: ${{ matrix.os }}
259+
strategy:
260+
fail-fast: false
261+
matrix:
262+
os: [ubuntu-latest, macos-latest, windows-2019]
263+
python-version: ["3.11"]
264+
steps:
265+
- uses: actions/checkout@v3
266+
- name: Set up Python ${{ matrix.python-version }}
267+
uses: actions/setup-python@v4
268+
with:
269+
python-version: ${{ matrix.python-version }}
270+
- name: Install packages and dependencies for all tests
271+
run: |
272+
python -m pip install --upgrade pip wheel
273+
pip install pytest
274+
- name: Install packages and dependencies for Context Handling
275+
run: |
276+
pip install -e .
277+
- name: Set AUTOGEN_USE_DOCKER based on OS
278+
shell: bash
279+
run: |
280+
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
281+
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
282+
fi
283+
- name: Coverage
284+
run: |
285+
pip install "coverage>=5.3"
286+
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
287+
coverage xml
288+
- name: Upload coverage to Codecov
289+
uses: codecov/codecov-action@v3
290+
with:
291+
file: ./coverage.xml
292+
flags: unittests
Original file line numberDiff line numberDiff line change
@@ -1,5 +0,0 @@
1-
from .teachability import Teachability
2-
from .agent_capability import AgentCapability
3-
4-
5-
__all__ = ["Teachability", "AgentCapability"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import sys
2+
from termcolor import colored
3+
from typing import Dict, Optional, List
4+
from autogen import ConversableAgent
5+
from autogen import token_count_utils
6+
7+
8+
class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes it on every invocation.

    This capability class enables various strategies to transform chat history, such as:
    - Truncate messages: Truncate each message to a first maximum number of tokens.
    - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
    - Limit number of tokens: Truncate the chat history to the number of recent messages that fit in
    a maximum number of tokens.
    Note that the system message, because of its special significance, is always kept as is.

    The three strategies can be combined. For example, when each of these parameters is specified
    they are used in the following order:
    1. First truncate messages to a maximum number of tokens
    2. Second, it limits the number of messages to keep
    3. Third, it limits the total number of tokens in the chat history

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        # Unset (or zero) limits default to sys.maxsize, i.e. effectively "no limit".
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: ConversableAgent):
        """
        Adds TransformChatHistory capability to the given agent.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Truncate the chat history according to the configured limits.

        Args:
            messages: List of messages to process.

        Returns:
            List of messages with the first system message (kept verbatim) and
            the most recent messages that fit within the configured limits.
        """
        # BUGFIX: the original used a shallow `messages.copy()` and then wrote
        # truncated content back into the shared message dicts, permanently
        # mutating the caller's chat history (and truncating the system
        # message), despite the docstring's promise. Copy each dict instead.
        messages = [msg.copy() for msg in messages]

        # Token count of the untouched history, used only for reporting below.
        original_tokens = sum(token_count_utils.count_token(msg["content"]) for msg in messages)

        processed_messages = []
        rest_messages = messages

        # The first message, when it is a system message, is always kept as is.
        if messages and messages[0]["role"] == "system":
            processed_messages.append(messages[0])
            rest_messages = messages[1:]

        # Strategy 1: truncate each (non-system) message to max_tokens_per_message tokens.
        for msg in rest_messages:
            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)

        # Strategy 2: keep at most the last max_messages messages.
        # Strategy 3: stop adding once the running total would exceed max_tokens.
        # NOTE(review): when the token budget is hit mid-window, the *newest*
        # messages in the window are the ones dropped — confirm this is intended.
        processed_messages_tokens = 0
        for msg in rest_messages[-self.max_messages :]:
            msg_tokens = token_count_utils.count_token(msg["content"])
            if processed_messages_tokens + msg_tokens > self.max_tokens:
                break
            processed_messages.append(msg)
            processed_messages_tokens += msg_tokens

        # Report how much was removed, comparing against the *pre-truncation*
        # totals (the original counted tokens after per-message truncation, so
        # the reported numbers understated how much was actually removed).
        kept_tokens = sum(token_count_utils.count_token(msg["content"]) for msg in processed_messages)
        num_truncated = len(messages) - len(processed_messages)
        if num_truncated > 0 or original_tokens > kept_tokens:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {original_tokens - kept_tokens} tokens.", "yellow"))
        return processed_messages
92+
def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that the number of tokens is at most max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        The longest prefix of ``text`` whose token count stays within
        ``max_tokens`` (the whole string when it already fits).
    """
    # Fast path: skip the char-by-char scan (which calls count_token once per
    # character, i.e. quadratic work) when the text already fits the budget.
    if token_count_utils.count_token(text) <= max_tokens:
        return text

    truncated_string = ""
    for char in text:
        truncated_string += char
        count = token_count_utils.count_token(truncated_string)
        # BUGFIX: the original tested `== max_tokens`; when appending one
        # character jumped the count from below to above the limit (a
        # multi-character token), the loop never stopped and the string was
        # returned untruncated.
        if count >= max_tokens:
            if count > max_tokens:
                # Drop the character that pushed the count past the limit.
                truncated_string = truncated_string[:-1]
            break
    return truncated_string

autogen/agentchat/conversable_agent.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__(
194194

195195
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
196196
# New hookable methods should be added to this list as required to support new agent capabilities.
197-
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
197+
self.hook_lists = {self.process_last_message: [], self.process_all_messages: []}
198198

199199
def register_reply(
200200
self,
@@ -1528,6 +1528,10 @@ def generate_reply(
15281528
if messages is None:
15291529
messages = self._oai_messages[sender]
15301530

1531+
# Call the hookable method that gives registered hooks a chance to process all messages.
1532+
# Message modifications do not affect the incoming messages or self._oai_messages.
1533+
messages = self.process_all_messages(messages)
1534+
15311535
# Call the hookable method that gives registered hooks a chance to process the last message.
15321536
# Message modifications do not affect the incoming messages or self._oai_messages.
15331537
messages = self.process_last_message(messages)
@@ -1584,6 +1588,10 @@ async def a_generate_reply(
15841588
if messages is None:
15851589
messages = self._oai_messages[sender]
15861590

1591+
# Call the hookable method that gives registered hooks a chance to process all messages.
1592+
# Message modifications do not affect the incoming messages or self._oai_messages.
1593+
messages = self.process_all_messages(messages)
1594+
15871595
# Call the hookable method that gives registered hooks a chance to process the last message.
15881596
# Message modifications do not affect the incoming messages or self._oai_messages.
15891597
messages = self.process_last_message(messages)
@@ -2211,6 +2219,21 @@ def register_hook(self, hookable_method: Callable, hook: Callable):
22112219
assert hook not in hook_list, f"{hook} is already registered as a hook."
22122220
hook_list.append(hook)
22132221

2222+
def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
    """
    Calls any registered capability hooks to process all messages, potentially modifying the messages.
    """
    hooks = self.hook_lists[self.process_all_messages]

    # Guard: with no registered hooks or no message list, there is nothing to do.
    if messages is None or not hooks:
        return messages

    # Apply each hook in registration order, feeding each one the output of
    # the previous hook.
    result = messages
    for hook in hooks:
        result = hook(result)
    return result
22142237
def process_last_message(self, messages):
22152238
"""
22162239
Calls any registered capability hooks to use and potentially modify the text of the last message,

0 commit comments

Comments
 (0)