@@ -1,11 +1,15 @@
 import copy
+import json
 import sys
 from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

 import tiktoken
 from termcolor import colored

 from autogen import token_count_utils
+from autogen.cache import AbstractCache, Cache
+
+from .text_compressors import LLMLingua, TextCompressor


 class MessageTransform(Protocol):
@@ -156,7 +160,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         assert self._min_tokens is not None

         # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
-        if not self._are_min_tokens_reached(messages):
+        if not _min_tokens_reached(messages, self._min_tokens):
             return messages

         temp_messages = copy.deepcopy(messages)
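
For reference, the `_min_tokens_reached` helper used here is added as a module-level function later in this diff: it returns True when no threshold is configured, or when the combined token count of the message contents meets the threshold. A minimal sketch of the resulting early-return behaviour (the import path is an assumption based on this file's relative imports, and the token counts are illustrative):

```python
# Assumed module path; adjust to where transforms.py lives in your install.
from autogen.agentchat.contrib.capabilities.transforms import _min_tokens_reached

messages = [{"role": "user", "content": "hello world"}]

print(_min_tokens_reached(messages, None))   # True: no threshold set, so the transform always proceeds
print(_min_tokens_reached(messages, 1))      # True: the content easily reaches one token
print(_min_tokens_reached(messages, 1000))   # False: apply_transform returns the messages unchanged
```
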
@@ -205,19 +209,6 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
             return logs_str, True
         return "No tokens were truncated.", False

-    def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
-        """
-        Returns True if no minimum tokens restrictions are applied.
-
-        Either if the total number of tokens in the messages is greater than or equal to the `min_theshold_tokens`,
-        or no minimum tokens threshold is set.
-        """
-        if not self._min_tokens:
-            return True
-
-        messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
-        return messages_tokens >= self._min_tokens
-
     def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
         if isinstance(contents, str):
             return self._truncate_tokens(contents, n_tokens)
@@ -268,7 +259,7 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int

         return max_tokens if max_tokens is not None else sys.maxsize

-    def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
+    def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
         if min_tokens is None:
             return 0
         if min_tokens < 0:
@@ -278,6 +269,154 @@ def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int
         return min_tokens


+class TextMessageCompressor:
+    """A transform for compressing text messages in a conversation history.
+
+    It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
+    processing and response generation by downstream models.
+    """
+
+    def __init__(
+        self,
+        text_compressor: Optional[TextCompressor] = None,
+        min_tokens: Optional[int] = None,
+        compression_params: Dict = dict(),
+        cache: Optional[AbstractCache] = Cache.disk(),
+    ):
+        """
+        Args:
+            text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
+                protocol. If None, it defaults to LLMLingua.
+            min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be
+                greater than 0 if not None. If None, no threshold-based compression is applied.
+            compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
+                dictionary.
+            cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
+                If None, no caching will be used.
+        """
+
+        if text_compressor is None:
+            text_compressor = LLMLingua()
+
+        self._validate_min_tokens(min_tokens)
+
+        self._text_compressor = text_compressor
+        self._min_tokens = min_tokens
+        self._compression_args = compression_params
+        self._cache = cache
+
+        # Track the most recent token savings so get_logs can report them.
+        self._recent_tokens_savings = 0
+
+    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+        """Applies compression to messages in a conversation history based on the specified configuration.
+
+        The function processes each message according to the `compression_params` and `min_tokens` settings, applying
+        the specified compression configuration and returning a new list of messages with reduced token counts
+        where possible.
+
+        Args:
+            messages (List[Dict]): A list of message dictionaries to be compressed.
+
+        Returns:
+            List[Dict]: A list of dictionaries with the message content compressed according to the configured
+                method and scope.
+        """
+        # Make sure there is at least one message
+        if not messages:
+            return messages
+
+        # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+        if not _min_tokens_reached(messages, self._min_tokens):
+            return messages
+
+        total_savings = 0
+        processed_messages = messages.copy()
+        for message in processed_messages:
+            # Some messages may not have content.
+            if not isinstance(message.get("content"), (str, list)):
+                continue
+
+            if _is_content_text_empty(message["content"]):
+                continue
+
+            cached_content = self._cache_get(message["content"])
+            if cached_content is not None:
+                savings, compressed_content = cached_content
+            else:
+                savings, compressed_content = self._compress(message["content"])

+
+            self._cache_set(message["content"], compressed_content, savings)
+
+            message["content"] = compressed_content
+            total_savings += savings
+
+        self._recent_tokens_savings = total_savings
+        return processed_messages
+
+    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+        if self._recent_tokens_savings > 0:
+            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
+        else:
+            return "No tokens saved with text compression.", False
+
+    def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
+        """Compresses the given text or multimodal content using the specified compression method."""
+        if isinstance(content, str):
+            return self._compress_text(content)
+        elif isinstance(content, list):
+            return self._compress_multimodal(content)
+        else:
+            return 0, content
+
+    def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
+        tokens_saved = 0
+        for msg in content:
+            if "text" in msg:
+                savings, msg["text"] = self._compress_text(msg["text"])
+                tokens_saved += savings
+        return tokens_saved, content
+
+    def _compress_text(self, text: str) -> Tuple[int, str]:
+        """Compresses the given text using the specified compression method."""
+        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
+
+        savings = 0
+        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
+            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
+
+        return savings, compressed_text["compressed_prompt"]
+
+    def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
+        if self._cache:
+            cached_value = self._cache.get(self._cache_key(content))
+            if cached_value:
+                return cached_value
+
+    def _cache_set(
+        self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
+    ):
+        if self._cache:
+            value = (tokens_saved, json.dumps(compressed_content))
+            self._cache.set(self._cache_key(content), value)
+
+    def _cache_key(self, content: Union[str, List[Dict]]) -> str:
+        return f"{json.dumps(content)}_{self._min_tokens}"
+
+    def _validate_min_tokens(self, min_tokens: Optional[int]):
+        if min_tokens is not None and min_tokens <= 0:
+            raise ValueError("min_tokens must be greater than 0 or None")
+
+
+def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
+    """Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
+    if not min_tokens:
+        return True
+
+    messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
+    return messages_tokens >= min_tokens
+
+
 def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
     token_count = 0
     if isinstance(content, str):
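
A minimal usage sketch for the new `TextMessageCompressor` transform. It assumes the optional `llmlingua` dependency is installed so the default `LLMLingua` compressor can be constructed; the import paths are inferred from this file's relative imports, and `target_token` is an llmlingua compression option shown purely for illustration:

```python
import copy

# Assumed module paths; adjust to where these files live in your install.
from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor

# Only compress once the history reaches 256 tokens, and disable caching.
compressor = TextMessageCompressor(
    text_compressor=LLMLingua(),
    min_tokens=256,
    compression_params={"target_token": 128},  # illustrative llmlingua argument
    cache=None,
)

messages = [{"role": "user", "content": "A very long prompt that should be shrunk. " * 200}]

# apply_transform modifies message contents in place (messages.copy() is shallow),
# so deep-copy first if the originals must be preserved.
compressed = compressor.apply_transform(copy.deepcopy(messages))

logs, compressed_any = compressor.get_logs(messages, compressed)
print(logs)  # e.g. "1234 tokens saved with text compression."
```

Because `_cache_key` combines the JSON-serialized content with `min_tokens`, identical message contents are only sent to the compressor once when a cache is configured.
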
@@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
         for item in content:
             token_count += _count_tokens(item.get("text", ""))
     return token_count
+
+
+def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
+    if isinstance(content, str):
+        return content == ""
+    elif isinstance(content, list):
+        return all(_is_content_text_empty(item.get("text", "")) for item in content)
+    else:
+        return False
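
Finally, a rough sketch of how the two module-level content helpers behave on multimodal content, based on the definitions above (the import path and the image entry are illustrative):

```python
# Assumed module path; adjust to where transforms.py lives in your install.
from autogen.agentchat.contrib.capabilities.transforms import _count_tokens, _is_content_text_empty

content = [
    {"type": "text", "text": ""},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]

# Every "text" entry is empty (the image item contributes "" via item.get("text", "")),
# so TextMessageCompressor.apply_transform skips this message entirely.
print(_is_content_text_empty(content))  # True

# _count_tokens applies the same item.get("text", "") rule, so this content counts as 0 tokens.
print(_count_tokens(content))  # 0
```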