
Proxy PR for Long Context Capability 1513 #1591

Merged
merged 31 commits on Feb 8, 2024
Changes from 25 commits

Commits (31 total)
1eb1184  Add new capability to handle long context (gagb, Feb 2, 2024)
a30a1d0  Make print conditional (gagb, Feb 2, 2024)
495bd90  Remove superfluous comment (gagb, Feb 2, 2024)
6ca3cab  Fix msg order (gagb, Feb 2, 2024)
d55264a  Allow user to specify max_tokens (gagb, Feb 2, 2024)
8ac4e8c  Add ability to specify max_tokens per message; improve name (gagb, Feb 3, 2024)
e8a3aae  Improve doc and readability (gagb, Feb 3, 2024)
7ab5543  Add tests (gagb, Feb 3, 2024)
c72e534  Merge branch 'main' into long_context (gagb, Feb 3, 2024)
96b2e6d  Improve documentation and add tests per Erik and Chi's feedback (gagb, Feb 5, 2024)
7854833  Update notebook (gagb, Feb 5, 2024)
3d5b8ee  Update doc string of add to agents (gagb, Feb 5, 2024)
752c781  Improve doc string (gagb, Feb 5, 2024)
7023d75  improve notebook (gagb, Feb 5, 2024)
7f99dc9  Update github workflows for context handling (gagb, Feb 6, 2024)
c3acdf8  Update docstring (ekzhu, Feb 6, 2024)
4f76d41  Merge branch 'long_context' of https://github.com/gagb/autogen into p… (ekzhu, Feb 6, 2024)
31932a8  update notebook to use raw config list. (ekzhu, Feb 6, 2024)
4f08497  Merge branch 'main' into long_context (sonichi, Feb 6, 2024)
567ca21  Update contrib-openai.yml remove _target (gagb, Feb 6, 2024)
165485e  Merge branch 'main' into long_context (gagb, Feb 6, 2024)
4c74d73  Fix code formatting (gagb, Feb 6, 2024)
89188a9  Merge branch 'main' into long_context (sonichi, Feb 7, 2024)
336f429  Merge branch 'main' into long_context (gagb, Feb 8, 2024)
283d329  Merge pull request #1590 from gagb/long_context (gagb, Feb 8, 2024)
4a84e3f  Fix workflow file (gagb, Feb 8, 2024)
e3b75aa  Merge branch 'main' into long_context (gagb, Feb 8, 2024)
eb664a6  Merge branch 'main' into long_context (sonichi, Feb 8, 2024)
a9c7d47  Merge branch 'main' into long_context (gagb, Feb 8, 2024)
ccc0362  Update .github/workflows/contrib-openai.yml (sonichi, Feb 8, 2024)
2bae6ee  Merge branch 'main' into long_context (sonichi, Feb 8, 2024)
41 changes: 40 additions & 1 deletion .github/workflows/contrib-openai.yml
@@ -4,7 +4,7 @@
name: OpenAI4ContribTests

on:
-  pull_request_target:
+  pull_request:
sonichi marked this conversation as resolved.
    branches: ['main']
    paths:
      - 'autogen/**'
@@ -260,3 +260,42 @@ jobs:
        with:
          file: ./coverage.xml
          flags: unittests
  ContextHandling:
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.11"]
    runs-on: ${{ matrix.os }}
    environment: openai1
    steps:
      # checkout to pr branch
      - name: Checkout
        uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies
        run: |
          docker --version
          python -m pip install --upgrade pip wheel
          pip install -e .
          python -c "import autogen"
          pip install coverage pytest
      - name: Coverage
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
          AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
          OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
          BING_API_KEY: ${{ secrets.BING_API_KEY }}
        run: |
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
37 changes: 37 additions & 0 deletions .github/workflows/contrib-tests.yml
@@ -253,3 +253,40 @@ jobs:
        with:
          file: ./coverage.xml
          flags: unittests

  ContextHandling:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-2019]
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies for all tests
        run: |
          python -m pip install --upgrade pip wheel
          pip install pytest
      - name: Install packages and dependencies for Context Handling
        run: |
          pip install -e .
      - name: Set AUTOGEN_USE_DOCKER based on OS
        shell: bash
        run: |
          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
          fi
      - name: Coverage
        run: |
          pip install coverage>=5.3
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
5 changes: 0 additions & 5 deletions autogen/agentchat/contrib/capabilities/__init__.py
@@ -1,5 +0,0 @@
from .teachability import Teachability
from .agent_capability import AgentCapability


__all__ = ["Teachability", "AgentCapability"]
108 changes: 108 additions & 0 deletions autogen/agentchat/contrib/capabilities/context_handling.py
@@ -0,0 +1,108 @@
import sys
from termcolor import colored
from typing import Dict, Optional, List
from autogen import ConversableAgent
from autogen import token_count_utils


class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes it on every invocation.

    This capability class enables various strategies to transform chat history, such as:
    - Truncate messages: Truncate each message to a maximum number of tokens.
    - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
    - Limit number of tokens: Truncate the chat history to the most recent messages that fit within
      a maximum number of tokens.
    Note that the system message, because of its special significance, is always kept as is.

    The three strategies can be combined. When all of these parameters are specified,
    they are applied in the following order:
    1. First, truncate each message to a maximum number of tokens.
    2. Second, limit the number of messages to keep.
    3. Third, limit the total number of tokens in the chat history.

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: ConversableAgent):
        """
        Adds TransformChatHistory capability to the given agent.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Args:
            messages: List of messages to process.

        Returns:
            List of messages with the first system message and the last max_messages messages,
            truncated to fit within the configured token limits.
        """
        processed_messages = []
        messages = messages.copy()
        rest_messages = messages

        # Check if the first message is a system message and append it to the processed messages.
        if len(messages) > 0:
            if messages[0]["role"] == "system":
                msg = messages[0]
                processed_messages.append(msg)
                rest_messages = messages[1:]

        processed_messages_tokens = 0
        for msg in messages:
            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)

        # Iterate through the rest of the messages and append them to the processed messages.
        for msg in rest_messages[-self.max_messages :]:
            msg_tokens = token_count_utils.count_token(msg["content"])
            if processed_messages_tokens + msg_tokens > self.max_tokens:
                break
            processed_messages.append(msg)
            processed_messages_tokens += msg_tokens

        total_tokens = 0
        for msg in messages:
            total_tokens += token_count_utils.count_token(msg["content"])

        num_truncated = len(messages) - len(processed_messages)
        if num_truncated > 0 or total_tokens > processed_messages_tokens:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
        return processed_messages


def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that the number of tokens is at most max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        Truncated string.
    """
    truncated_string = ""
    for char in text:
        truncated_string += char
        if token_count_utils.count_token(truncated_string) >= max_tokens:
            break
    return truncated_string
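The three-stage order described in the docstring above can be sketched independently of autogen. The snippet below is a minimal illustration, assuming a whitespace word count as a stand-in for `token_count_utils.count_token`; `transform_history` and `count_tokens` are hypothetical helpers written for this sketch, not part of the PR:

```python
# Minimal sketch of the three-stage truncation order used by TransformChatHistory.
# Assumption: whitespace-separated words stand in for tokens purely for illustration.
from typing import Dict, List


def count_tokens(text: str) -> int:
    # Hypothetical stand-in for token_count_utils.count_token.
    return len(text.split())


def transform_history(
    messages: List[Dict],
    max_tokens_per_message: int,
    max_messages: int,
    max_tokens: int,
) -> List[Dict]:
    processed: List[Dict] = []
    rest = messages
    # The system message, if present, is always kept as is.
    if messages and messages[0]["role"] == "system":
        processed.append(messages[0])
        rest = messages[1:]
    # 1. Truncate each remaining message to max_tokens_per_message tokens.
    rest = [
        {**m, "content": " ".join(m["content"].split()[:max_tokens_per_message])}
        for m in rest
    ]
    # 2. Keep only the max_messages most recent messages.
    # 3. Stop once the running token total would exceed max_tokens.
    used = 0
    for msg in rest[-max_messages:]:
        n = count_tokens(msg["content"])
        if used + n > max_tokens:
            break
        processed.append(msg)
        used += n
    return processed


history = [
    {"role": "system", "content": "be brief"},
    {"role": "user", "content": "one two three four five"},
    {"role": "assistant", "content": "six seven"},
    {"role": "user", "content": "eight nine ten"},
]
out = transform_history(history, max_tokens_per_message=3, max_messages=2, max_tokens=5)
```

As in `_transform_messages`, the token-budget loop stops at the first message that would exceed the limit, so everything after it in the window is dropped as well.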
25 changes: 24 additions & 1 deletion autogen/agentchat/conversable_agent.py
@@ -194,7 +194,7 @@ def __init__(

        # Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
        # New hookable methods should be added to this list as required to support new agent capabilities.
-        self.hook_lists = {self.process_last_message: []}  # This is currently the only hookable method.
+        self.hook_lists = {self.process_last_message: [], self.process_all_messages: []}

    def register_reply(
        self,
@@ -1524,6 +1524,10 @@ def generate_reply(
        if messages is None:
            messages = self._oai_messages[sender]

+        # Call the hookable method that gives registered hooks a chance to process all messages.
+        # Message modifications do not affect the incoming messages or self._oai_messages.
+        messages = self.process_all_messages(messages)
+
        # Call the hookable method that gives registered hooks a chance to process the last message.
        # Message modifications do not affect the incoming messages or self._oai_messages.
        messages = self.process_last_message(messages)
@@ -1580,6 +1584,10 @@ async def a_generate_reply(
        if messages is None:
            messages = self._oai_messages[sender]

+        # Call the hookable method that gives registered hooks a chance to process all messages.
+        # Message modifications do not affect the incoming messages or self._oai_messages.
+        messages = self.process_all_messages(messages)
+
        # Call the hookable method that gives registered hooks a chance to process the last message.
        # Message modifications do not affect the incoming messages or self._oai_messages.
        messages = self.process_last_message(messages)
@@ -2207,6 +2215,21 @@ def register_hook(self, hookable_method: Callable, hook: Callable):
        assert hook not in hook_list, f"{hook} is already registered as a hook."
        hook_list.append(hook)

+    def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
+        """
+        Calls any registered capability hooks to process all messages, potentially modifying the messages.
+        """
+        hook_list = self.hook_lists[self.process_all_messages]
+        # If no hooks are registered, or if there are no messages to process, return the original message list.
+        if len(hook_list) == 0 or messages is None:
+            return messages
+
+        # Call each hook (in order of registration) to process the messages.
+        processed_messages = messages
+        for hook in hook_list:
+            processed_messages = hook(processed_messages)
+        return processed_messages
+
    def process_last_message(self, messages):
        """
        Calls any registered capability hooks to use and potentially modify the text of the last message,
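The hook-list pattern this diff extends can be illustrated in isolation. The `Hookable` class below is a hypothetical stand-in for `ConversableAgent` that reproduces only the `register_hook` / `process_all_messages` interplay:

```python
# Minimal sketch of the hook-list mechanism: one list of hooks per hookable
# method, applied in registration order. Hookable is a hypothetical stand-in
# for ConversableAgent, written for this sketch only.
from typing import Callable, Dict, List


class Hookable:
    def __init__(self):
        # Hooks are kept in lists, indexed by hookable method.
        self.hook_lists = {self.process_all_messages: []}

    def register_hook(self, hookable_method: Callable, hook: Callable):
        assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method."
        hook_list = self.hook_lists[hookable_method]
        assert hook not in hook_list, f"{hook} is already registered as a hook."
        hook_list.append(hook)

    def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
        hook_list = self.hook_lists[self.process_all_messages]
        # With no hooks, or no messages, return the input unchanged.
        if len(hook_list) == 0 or messages is None:
            return messages
        processed = messages
        for hook in hook_list:
            processed = hook(processed)
        return processed


agent = Hookable()
# Hooks run in registration order: first drop empty messages, then cap the history at 2.
agent.register_hook(agent.process_all_messages, lambda msgs: [m for m in msgs if m["content"]])
agent.register_hook(agent.process_all_messages, lambda msgs: msgs[-2:])
result = agent.process_all_messages(
    [{"content": "a"}, {"content": ""}, {"content": "b"}, {"content": "c"}]
)
```

Registration order matters: capping before filtering could return fewer than two non-empty messages, which is why capabilities like `TransformChatHistory` register a single hook that applies its stages internally in a fixed order.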