1
1
import copy
2
- import json
3
2
import sys
4
3
from typing import Any , Dict , List , Optional , Protocol , Tuple , Union
5
4
8
7
9
8
from autogen import token_count_utils
10
9
from autogen .cache import AbstractCache , Cache
11
- from autogen .oai . openai_utils import filter_config
10
+ from autogen .types import MessageContentType
12
11
12
+ from . import transforms_util
13
13
from .text_compressors import LLMLingua , TextCompressor
14
14
15
15
@@ -169,7 +169,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
169
169
assert self ._min_tokens is not None
170
170
171
171
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
172
- if not _min_tokens_reached (messages , self ._min_tokens ):
172
+ if not transforms_util . min_tokens_reached (messages , self ._min_tokens ):
173
173
return messages
174
174
175
175
temp_messages = copy .deepcopy (messages )
@@ -178,13 +178,13 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
178
178
179
179
for msg in reversed (temp_messages ):
180
180
# Some messages may not have content.
181
- if not _is_content_right_type (msg .get ("content" )):
181
+ if not transforms_util . is_content_right_type (msg .get ("content" )):
182
182
processed_messages .insert (0 , msg )
183
183
continue
184
184
185
- if not _should_transform_message (msg , self ._filter_dict , self ._exclude_filter ):
185
+ if not transforms_util . should_transform_message (msg , self ._filter_dict , self ._exclude_filter ):
186
186
processed_messages .insert (0 , msg )
187
- processed_messages_tokens += _count_tokens (msg ["content" ])
187
+ processed_messages_tokens += transforms_util . count_text_tokens (msg ["content" ])
188
188
continue
189
189
190
190
expected_tokens_remained = self ._max_tokens - processed_messages_tokens - self ._max_tokens_per_message
@@ -199,7 +199,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
199
199
break
200
200
201
201
msg ["content" ] = self ._truncate_str_to_tokens (msg ["content" ], self ._max_tokens_per_message )
202
- msg_tokens = _count_tokens (msg ["content" ])
202
+ msg_tokens = transforms_util . count_text_tokens (msg ["content" ])
203
203
204
204
# prepend the message to the list to preserve order
205
205
processed_messages_tokens += msg_tokens
@@ -209,10 +209,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
209
209
210
210
def get_logs (self , pre_transform_messages : List [Dict ], post_transform_messages : List [Dict ]) -> Tuple [str , bool ]:
211
211
pre_transform_messages_tokens = sum (
212
- _count_tokens (msg ["content" ]) for msg in pre_transform_messages if "content" in msg
212
+ transforms_util . count_text_tokens (msg ["content" ]) for msg in pre_transform_messages if "content" in msg
213
213
)
214
214
post_transform_messages_tokens = sum (
215
- _count_tokens (msg ["content" ]) for msg in post_transform_messages if "content" in msg
215
+ transforms_util . count_text_tokens (msg ["content" ]) for msg in post_transform_messages if "content" in msg
216
216
)
217
217
218
218
if post_transform_messages_tokens < pre_transform_messages_tokens :
@@ -349,31 +349,32 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
349
349
return messages
350
350
351
351
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
352
- if not _min_tokens_reached (messages , self ._min_tokens ):
352
+ if not transforms_util . min_tokens_reached (messages , self ._min_tokens ):
353
353
return messages
354
354
355
355
total_savings = 0
356
356
processed_messages = messages .copy ()
357
357
for message in processed_messages :
358
358
# Some messages may not have content.
359
- if not _is_content_right_type (message .get ("content" )):
359
+ if not transforms_util . is_content_right_type (message .get ("content" )):
360
360
continue
361
361
362
- if not _should_transform_message (message , self ._filter_dict , self ._exclude_filter ):
362
+ if not transforms_util . should_transform_message (message , self ._filter_dict , self ._exclude_filter ):
363
363
continue
364
364
365
- if _is_content_text_empty (message ["content" ]):
365
+ if transforms_util . is_content_text_empty (message ["content" ]):
366
366
continue
367
367
368
- cached_content = self ._cache_get (message ["content" ])
368
+ cache_key = transforms_util .cache_key (message ["content" ], self ._min_tokens )
369
+ cached_content = transforms_util .cache_content_get (self ._cache , cache_key )
369
370
if cached_content is not None :
370
- savings , compressed_content = cached_content
371
+ message [ "content" ], savings = cached_content
371
372
else :
372
- savings , compressed_content = self ._compress (message ["content" ])
373
+ message [ "content" ], savings = self ._compress (message ["content" ])
373
374
374
- self . _cache_set ( message ["content" ], compressed_content , savings )
375
+ transforms_util . cache_content_set ( self . _cache , cache_key , message ["content" ], savings )
375
376
376
- message [ "content" ] = compressed_content
377
+ assert isinstance ( savings , int )
377
378
total_savings += savings
378
379
379
380
self ._recent_tokens_savings = total_savings
@@ -385,88 +386,38 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
385
386
else :
386
387
return "No tokens saved with text compression." , False
387
388
388
- def _compress (self , content : Union [ str , List [ Dict ]] ) -> Tuple [int , Union [ str , List [ Dict ]] ]:
389
+ def _compress (self , content : MessageContentType ) -> Tuple [MessageContentType , int ]:
389
390
"""Compresses the given text or multimodal content using the specified compression method."""
390
391
if isinstance (content , str ):
391
392
return self ._compress_text (content )
392
393
elif isinstance (content , list ):
393
394
return self ._compress_multimodal (content )
394
395
else :
395
- return 0 , content
396
+ return content , 0
396
397
397
- def _compress_multimodal (self , content : List [ Dict ] ) -> Tuple [int , List [ Dict ] ]:
398
+ def _compress_multimodal (self , content : MessageContentType ) -> Tuple [MessageContentType , int ]:
398
399
tokens_saved = 0
399
- for msg in content :
400
- if "text" in msg :
401
- savings , msg ["text" ] = self ._compress_text (msg ["text" ])
400
+ for item in content :
401
+ if isinstance (item , dict ) and "text" in item :
402
+ item ["text" ], savings = self ._compress_text (item ["text" ])
403
+ tokens_saved += savings
404
+
405
+ elif isinstance (item , str ):
406
+ item , savings = self ._compress_text (item )
402
407
tokens_saved += savings
403
- return tokens_saved , content
404
408
405
- def _compress_text (self , text : str ) -> Tuple [int , str ]:
409
+ return content , tokens_saved
410
+
411
+ def _compress_text (self , text : str ) -> Tuple [str , int ]:
406
412
"""Compresses the given text using the specified compression method."""
407
413
compressed_text = self ._text_compressor .compress_text (text , ** self ._compression_args )
408
414
409
415
savings = 0
410
416
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text :
411
417
savings = compressed_text ["origin_tokens" ] - compressed_text ["compressed_tokens" ]
412
418
413
- return savings , compressed_text ["compressed_prompt" ]
414
-
415
- def _cache_get (self , content : Union [str , List [Dict ]]) -> Optional [Tuple [int , Union [str , List [Dict ]]]]:
416
- if self ._cache :
417
- cached_value = self ._cache .get (self ._cache_key (content ))
418
- if cached_value :
419
- return cached_value
420
-
421
- def _cache_set (
422
- self , content : Union [str , List [Dict ]], compressed_content : Union [str , List [Dict ]], tokens_saved : int
423
- ):
424
- if self ._cache :
425
- value = (tokens_saved , compressed_content )
426
- self ._cache .set (self ._cache_key (content ), value )
427
-
428
- def _cache_key (self , content : Union [str , List [Dict ]]) -> str :
429
- return f"{ json .dumps (content )} _{ self ._min_tokens } "
419
+ return compressed_text ["compressed_prompt" ], savings
430
420
431
421
def _validate_min_tokens (self , min_tokens : Optional [int ]):
432
422
if min_tokens is not None and min_tokens <= 0 :
433
423
raise ValueError ("min_tokens must be greater than 0 or None" )
434
-
435
-
436
- def _min_tokens_reached (messages : List [Dict ], min_tokens : Optional [int ]) -> bool :
437
- """Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
438
- if not min_tokens :
439
- return True
440
-
441
- messages_tokens = sum (_count_tokens (msg ["content" ]) for msg in messages if "content" in msg )
442
- return messages_tokens >= min_tokens
443
-
444
-
445
- def _count_tokens (content : Union [str , List [Dict [str , Any ]]]) -> int :
446
- token_count = 0
447
- if isinstance (content , str ):
448
- token_count = token_count_utils .count_token (content )
449
- elif isinstance (content , list ):
450
- for item in content :
451
- token_count += _count_tokens (item .get ("text" , "" ))
452
- return token_count
453
-
454
-
455
- def _is_content_right_type (content : Any ) -> bool :
456
- return isinstance (content , (str , list ))
457
-
458
-
459
- def _is_content_text_empty (content : Union [str , List [Dict [str , Any ]]]) -> bool :
460
- if isinstance (content , str ):
461
- return content == ""
462
- elif isinstance (content , list ):
463
- return all (_is_content_text_empty (item .get ("text" , "" )) for item in content )
464
- else :
465
- return False
466
-
467
-
468
- def _should_transform_message (message : Dict [str , Any ], filter_dict : Optional [Dict [str , Any ]], exclude : bool ) -> bool :
469
- if not filter_dict :
470
- return True
471
-
472
- return len (filter_config ([message ], filter_dict , exclude )) > 0
0 commit comments