@@ -85,7 +85,8 @@ class MessageTokenLimiter:
     2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
        and other types of content, only the text content is truncated.
     3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
-       exceeds this limit, the current message being processed as well as any remaining messages are discarded.
+       exceeds this limit, the current message being processed is truncated to fit within the total token limit and
+       any remaining messages are discarded.
     4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
        original message order.
     """
@@ -128,13 +129,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

         for msg in reversed(temp_messages):
-            msg["content"] = self._truncate_str_to_tokens(msg["content"])
-            msg_tokens = _count_tokens(msg["content"])
+            expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

-            # If adding this message would exceed the token limit, discard it and all remaining messages
-            if processed_messages_tokens + msg_tokens > self._max_tokens:
+            # If adding this message would exceed the token limit, truncate the last message to meet the total token
+            # limit and discard all remaining messages
+            if expected_tokens_remained < 0:
+                msg["content"] = self._truncate_str_to_tokens(
+                    msg["content"], self._max_tokens - processed_messages_tokens
+                )
+                processed_messages.insert(0, msg)
                 break

+            msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
+            msg_tokens = _count_tokens(msg["content"])
+
             # prepend the message to the list to preserve order
             processed_messages_tokens += msg_tokens
             processed_messages.insert(0, msg)
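To make the budget arithmetic in this hunk concrete, here is a self-contained sketch of the same walk-from-the-end logic. It swaps tiktoken for a toy whitespace tokenizer, and all names are illustrative rather than taken from the PR:

def truncate_history(messages, max_tokens, max_tokens_per_message):
    """Walk the history newest-first, capping each message and the running total."""

    def count(s):  # toy tokenizer: one token per whitespace-separated word
        return len(s.split())

    def clip(s, n):  # keep the first n "tokens"
        return " ".join(s.split()[:n])

    processed, used = [], 0
    for msg in reversed(messages):
        remaining = max_tokens - used - max_tokens_per_message
        if remaining < 0:
            # Boundary message: truncate to whatever budget is left, then stop.
            processed.insert(0, {**msg, "content": clip(msg["content"], max_tokens - used)})
            break
        msg = {**msg, "content": clip(msg["content"], max_tokens_per_message)}
        used += count(msg["content"])
        processed.insert(0, msg)  # prepend to preserve original order
    return processed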
@@ -149,30 +157,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

         return processed_messages

-    def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
+    def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
         if isinstance(contents, str):
-            return self._truncate_tokens(contents)
+            return self._truncate_tokens(contents, n_tokens)
         elif isinstance(contents, list):
-            return self._truncate_multimodal_text(contents)
+            return self._truncate_multimodal_text(contents, n_tokens)
         else:
             raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

-    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
         """Truncates text content within a list of multimodal elements, preserving the overall structure."""
         tmp_contents = []
         for content in contents:
             if content["type"] == "text":
-                truncated_text = self._truncate_tokens(content["text"])
+                truncated_text = self._truncate_tokens(content["text"], n_tokens)
                 tmp_contents.append({"type": "text", "text": truncated_text})
             else:
                 tmp_contents.append(content)
         return tmp_contents

-    def _truncate_tokens(self, text: str) -> str:
+    def _truncate_tokens(self, text: str, n_tokens: int) -> str:
         encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

         encoded_tokens = encoding.encode(text)
-        truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
+        truncated_tokens = encoded_tokens[:n_tokens]
         truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

         return truncated_text
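The truncation primitive at the bottom of this hunk is an encode/slice/decode round trip with tiktoken. A minimal standalone sketch of that pattern (function name and model choice are illustrative):

import tiktoken

def truncate_tokens(text: str, n_tokens: int, model: str = "gpt-3.5-turbo") -> str:
    encoding = tiktoken.encoding_for_model(model)  # tokenizer matching the target model
    tokens = encoding.encode(text)
    return encoding.decode(tokens[:n_tokens])  # keep only the first n_tokens tokens

print(truncate_tokens("The quick brown fox jumps over the lazy dog", 4))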