Skip to content

Commit 4583572

Browse files
authored
[Fix] Improves Token Limiter (microsoft#2350)
* Improves the token limiter
* Improves the docstring
* Renames an argument
1 parent f0408ae commit 4583572

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

autogen/agentchat/contrib/capabilities/transforms.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ class MessageTokenLimiter:
8585
2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
8686
and other types of content, only the text content is truncated.
8787
3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
88-
exceeds this limit, the current message being processed as well as any remaining messages are discarded.
88+
exceeds this limit, the current message being processed gets truncated to meet the total token count and any
89+
remaining messages get discarded.
8990
4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
9091
original message order.
9192
"""
@@ -128,13 +129,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
128129
total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)
129130

130131
for msg in reversed(temp_messages):
131-
msg["content"] = self._truncate_str_to_tokens(msg["content"])
132-
msg_tokens = _count_tokens(msg["content"])
132+
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
133133

134-
# If adding this message would exceed the token limit, discard it and all remaining messages
135-
if processed_messages_tokens + msg_tokens > self._max_tokens:
134+
# If adding this message would exceed the token limit, truncate the last message to meet the total token
135+
# limit and discard all remaining messages
136+
if expected_tokens_remained < 0:
137+
msg["content"] = self._truncate_str_to_tokens(
138+
msg["content"], self._max_tokens - processed_messages_tokens
139+
)
140+
processed_messages.insert(0, msg)
136141
break
137142

143+
msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
144+
msg_tokens = _count_tokens(msg["content"])
145+
138146
# prepend the message to the list to preserve order
139147
processed_messages_tokens += msg_tokens
140148
processed_messages.insert(0, msg)
@@ -149,30 +157,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
149157

150158
return processed_messages
151159

152-
def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
160+
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
153161
if isinstance(contents, str):
154-
return self._truncate_tokens(contents)
162+
return self._truncate_tokens(contents, n_tokens)
155163
elif isinstance(contents, list):
156-
return self._truncate_multimodal_text(contents)
164+
return self._truncate_multimodal_text(contents, n_tokens)
157165
else:
158166
raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")
159167

160-
def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
168+
def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
161169
"""Truncates text content within a list of multimodal elements, preserving the overall structure."""
162170
tmp_contents = []
163171
for content in contents:
164172
if content["type"] == "text":
165-
truncated_text = self._truncate_tokens(content["text"])
173+
truncated_text = self._truncate_tokens(content["text"], n_tokens)
166174
tmp_contents.append({"type": "text", "text": truncated_text})
167175
else:
168176
tmp_contents.append(content)
169177
return tmp_contents
170178

171-
def _truncate_tokens(self, text: str) -> str:
179+
def _truncate_tokens(self, text: str, n_tokens: int) -> str:
172180
encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer
173181

174182
encoded_tokens = encoding.encode(text)
175-
truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
183+
truncated_tokens = encoded_tokens[:n_tokens]
176184
truncated_text = encoding.decode(truncated_tokens) # Decode back to text
177185

178186
return truncated_text

0 commit comments

Comments (0)