
Commit a34e4cc

dkirsche and gagb authored

Refactor transform_messages (#1631)

* Refactored the code to simplify it
* Optimized the function: instead of iterating over each character, guess at the size and then iterate by token
* Added tests
* Added missing tests
* Minor test fix
* Simplified token truncation by using tiktoken to encode and decode
* Updated the truncation notification message
* Fixed the llm_config spec to use os.environ
* Added a test case and fixed a bug in the loop

Co-authored-by: gagb <[email protected]>

1 parent d8a204a · commit a34e4cc
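The core of the change is the truncation rewrite described in the bullets: rather than counting tokens once per appended character, the string is encoded once, the token list is sliced, and the slice is decoded back to text. A minimal standalone sketch of that idea (illustrative only; the committed implementation is truncate_str_to_tokens in the diff below):

    import tiktoken

    # One encode and one decode per string, instead of one
    # count_token call per character of the input.
    enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")
    tokens = enc.encode("a long message that may exceed the per-message budget")
    print(enc.decode(tokens[:8]))  # keep at most the first 8 tokens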

File tree

3 files changed: +407 −199 lines changed

autogen/agentchat/contrib/capabilities/context_handling.py

+49 −34
@@ -3,6 +3,7 @@
 from typing import Dict, Optional, List
 from autogen import ConversableAgent
 from autogen import token_count_utils
+import tiktoken
 
 
 class TransformChatHistory:
@@ -53,56 +54,70 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
             messages: List of messages to process.
 
         Returns:
-            List of messages with the first system message and the last max_messages messages.
+            List of messages with the first system message and the last max_messages messages,
+            ensuring each message does not exceed max_tokens_per_message.
         """
+        temp_messages = messages.copy()
         processed_messages = []
-        messages = messages.copy()
-        rest_messages = messages
+        system_message = None
+        processed_messages_tokens = 0
 
-        # check if the first message is a system message and append it to the processed messages
-        if len(messages) > 0:
-            if messages[0]["role"] == "system":
-                msg = messages[0]
-                processed_messages.append(msg)
-                rest_messages = messages[1:]
+        if messages[0]["role"] == "system":
+            system_message = messages[0].copy()
+            temp_messages.pop(0)
 
-        processed_messages_tokens = 0
-        for msg in messages:
-            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
+        total_tokens = sum(
+            token_count_utils.count_token(msg["content"]) for msg in temp_messages
+        )  # Calculate tokens for all messages
 
-        # iterate through rest of the messages and append them to the processed messages
-        for msg in rest_messages[-self.max_messages :]:
+        # Truncate each message's content to a maximum token limit of each message
+
+        # Process recent messages first
+        for msg in reversed(temp_messages[-self.max_messages :]):
+            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
             msg_tokens = token_count_utils.count_token(msg["content"])
             if processed_messages_tokens + msg_tokens > self.max_tokens:
                 break
-            processed_messages.append(msg)
+            # append the message to the beginning of the list to preserve order
+            processed_messages = [msg] + processed_messages
             processed_messages_tokens += msg_tokens
-
-        total_tokens = 0
-        for msg in messages:
-            total_tokens += token_count_utils.count_token(msg["content"])
-
+        if system_message:
+            processed_messages.insert(0, system_message)
+        # Optionally, log the number of truncated messages and tokens if needed
         num_truncated = len(messages) - len(processed_messages)
+
        if num_truncated > 0 or total_tokens > processed_messages_tokens:
-            print(colored(f"Truncated {len(messages) - len(processed_messages)} messages.", "yellow"))
-            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
+            print(
+                colored(
+                    f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
+                    "yellow",
+                )
+            )
+            print(
+                colored(
+                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
+                    "yellow",
+                )
+            )
         return processed_messages
 
 
-def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
-    """
-    Truncate a string so that number of tokens in less than max_tokens.
+def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
+    """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.
 
     Args:
-        content: String to process.
-        max_tokens: Maximum number of tokens to keep.
+        text: The string to truncate.
+        max_tokens: The maximum number of tokens to keep.
+        model: The target OpenAI model for tokenization alignment.
 
     Returns:
-        Truncated string.
+        The truncated string.
     """
-    truncated_string = ""
-    for char in text:
-        truncated_string += char
-        if token_count_utils.count_token(truncated_string) == max_tokens:
-            break
-    return truncated_string
+
+    encoding = tiktoken.encoding_for_model(model)  # Get the appropriate tokenizer
+
+    encoded_tokens = encoding.encode(text)
+    truncated_tokens = encoded_tokens[:max_tokens]
+    truncated_text = encoding.decode(truncated_tokens)  # Decode back to text
+
+    return truncated_text
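For context, this is roughly how the capability is attached to an agent. A minimal usage sketch, assuming the constructor parameters shown in the diff and the capability's add_to_agent hook; the agent name and limits here are illustrative:

    from autogen import ConversableAgent
    from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory

    # Keep the system message, at most the 10 most recent messages,
    # at most 500 tokens per message, and at most 2000 tokens overall.
    capability = TransformChatHistory(max_tokens_per_message=500, max_messages=10, max_tokens=2000)

    assistant = ConversableAgent("assistant", llm_config=False)
    capability.add_to_agent(assistant)  # routes chat history through _transform_messages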
