Skip to content

Commit 372ac1e

Browse files
authored
Text Compression Transform (#2225)
* adds implementation * handles optional import * cleanup * updates github workflows * skip test if dependencies not installed * skip test if dependencies not installed * use cpu * skip openai * unskip openai * adds protocol * better docstr * minor fixes * updates optional dependencies docs * wip * update docstrings * wip * adds back llmlingua requirement * finalized protocol * improve docstr * guide complete * improve docstr * fix FAQ * added cache support * improve cache key * cache key fix + faq fix * improve docs * improve guide * args -> params * spelling
1 parent 5a3a8a5 commit 372ac1e

File tree

10 files changed

+503
-33
lines changed

10 files changed

+503
-33
lines changed

.github/workflows/contrib-tests.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ jobs:
400400
pip install pytest-cov>=5
401401
- name: Install packages and dependencies for Transform Messages
402402
run: |
403-
pip install -e .
403+
pip install -e '.[long-context]'
404404
- name: Set AUTOGEN_USE_DOCKER based on OS
405405
shell: bash
406406
run: |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any, Dict, Optional, Protocol

# llmlingua is an optional dependency. Importing this module must never fail for
# users who did not install it; the error is captured here and raised lazily when
# LLMLingua is actually instantiated.
IMPORT_ERROR: Optional[Exception] = None
try:
    import llmlingua
    from llmlingua import PromptCompressor
except ImportError:
    IMPORT_ERROR = ImportError(
        "LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
    )
    # Placeholder so later type references to PromptCompressor remain valid.
    PromptCompressor = object
13+
14+
15+
class TextCompressor(Protocol):
    """Defines a protocol for text compression to optimize agent interactions."""

    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
        """Compress the given string and report token statistics.

        Implementations accept a string and return a dictionary holding the compressed
        text under the 'compressed_text' key, along with any other relevant information.
        To enable saved-token accounting, the dictionary should also carry
        'origin_tokens' and 'compressed_tokens' entries.
        """
        ...
24+
25+
26+
class LLMLingua:
    """Compresses text messages using LLMLingua for improved efficiency in processing and
    response generation.

    NOTE: The effectiveness of compression and the resultant token savings can vary based on
    the content of the messages and the specific configurations used for the PromptCompressor.
    """

    def __init__(
        self,
        prompt_compressor_kwargs: Optional[Dict] = None,
        structured_compression: bool = False,
    ) -> None:
        """
        Args:
            prompt_compressor_kwargs (dict or None): A dictionary of keyword arguments for the
                PromptCompressor. Defaults to a dictionary with model_name set to
                "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
                use_llmlingua2 set to True, and device_map set to "cpu".
            structured_compression (bool): A flag indicating whether to use structured
                compression. If True, the structured_compress_prompt method of the
                PromptCompressor is used. Otherwise, the compress_prompt method is used.
                Defaults to False.

        Raises:
            ImportError: If the llmlingua library is not installed.
        """
        if IMPORT_ERROR:
            raise IMPORT_ERROR

        # Build the default kwargs inside the call instead of using a mutable default
        # argument, which would be shared across all instances.
        if prompt_compressor_kwargs is None:
            prompt_compressor_kwargs = dict(
                model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
                use_llmlingua2=True,
                device_map="cpu",
            )

        self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

        # Narrow the type for static checkers; PromptCompressor is only a placeholder
        # `object` when llmlingua is missing, and that case raised above.
        assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
        self._compression_method = (
            self._prompt_compressor.structured_compress_prompt
            if structured_compression
            else self._prompt_compressor.compress_prompt
        )

    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
        """Compresses the given text using the configured LLMLingua compression method."""
        return self._compression_method([text], **compression_params)

autogen/agentchat/contrib/capabilities/transforms.py

+163-15
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import copy
2+
import json
23
import sys
34
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
45

56
import tiktoken
67
from termcolor import colored
78

89
from autogen import token_count_utils
10+
from autogen.cache import AbstractCache, Cache
11+
12+
from .text_compressors import LLMLingua, TextCompressor
913

1014

1115
class MessageTransform(Protocol):
@@ -156,7 +160,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
156160
assert self._min_tokens is not None
157161

158162
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
159-
if not self._are_min_tokens_reached(messages):
163+
if not _min_tokens_reached(messages, self._min_tokens):
160164
return messages
161165

162166
temp_messages = copy.deepcopy(messages)
@@ -205,19 +209,6 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
205209
return logs_str, True
206210
return "No tokens were truncated.", False
207211

208-
def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
209-
"""
210-
Returns True if no minimum tokens restrictions are applied.
211-
212-
Either if the total number of tokens in the messages is greater than or equal to the `min_theshold_tokens`,
213-
or no minimum tokens threshold is set.
214-
"""
215-
if not self._min_tokens:
216-
return True
217-
218-
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
219-
return messages_tokens >= self._min_tokens
220-
221212
def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
222213
if isinstance(contents, str):
223214
return self._truncate_tokens(contents, n_tokens)
@@ -268,7 +259,7 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int
268259

269260
return max_tokens if max_tokens is not None else sys.maxsize
270261

271-
def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
262+
def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
272263
if min_tokens is None:
273264
return 0
274265
if min_tokens < 0:
@@ -278,6 +269,154 @@ def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
278269
return min_tokens
279270

280271

272+
class TextMessageCompressor:
    """A transform for compressing text messages in a conversation history.

    It uses a specified text compression method to reduce the token count of messages, which can lead
    to more efficient processing and response generation by downstream models.
    """

    def __init__(
        self,
        text_compressor: Optional[TextCompressor] = None,
        min_tokens: Optional[int] = None,
        compression_params: Optional[Dict] = None,
        cache: Optional[AbstractCache] = Cache.disk(),
    ):
        """
        Args:
            text_compressor (TextCompressor or None): An instance of a class that implements the
                TextCompressor protocol. If None, it defaults to LLMLingua.
            min_tokens (int or None): Minimum number of tokens in messages to apply the transformation.
                Must be greater than or equal to 0 if not None. If None, no threshold-based compression
                is applied.
            compression_params (dict or None): A dictionary of arguments for the compression method.
                Defaults to an empty dictionary.
            cache (None or AbstractCache): The cache client to use to store and retrieve previously
                compressed messages. If None, no caching will be used.

        Raises:
            ValueError: If min_tokens is not None and not greater than 0.
        """

        if text_compressor is None:
            text_compressor = LLMLingua()

        self._validate_min_tokens(min_tokens)

        self._text_compressor = text_compressor
        self._min_tokens = min_tokens
        # Fresh dict per instance: avoids the shared-mutable-default-argument pitfall.
        self._compression_args = {} if compression_params is None else compression_params
        self._cache = cache

        # Savings from the most recent apply_transform call, consumed by get_logs.
        self._recent_tokens_savings = 0

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies compression to messages in a conversation history based on the specified configuration.

        The function processes each message according to the `compression_params` and `min_tokens`
        settings, applying the specified compression configuration and returning a new list of messages
        with reduced token counts where possible. The input `messages` list is not modified.

        Args:
            messages (List[Dict]): A list of message dictionaries to be compressed.

        Returns:
            List[Dict]: A list of dictionaries with the message content compressed according to the
                configured method and scope.
        """
        # Make sure there is at least one message
        if not messages:
            return messages

        # if the total number of tokens in the messages is less than the min_tokens, return as is
        if not _min_tokens_reached(messages, self._min_tokens):
            return messages

        total_savings = 0
        # Deep copy so compressed content does not clobber the caller's message dicts.
        # A shallow copy would share the inner dicts and mutate them in place.
        processed_messages = copy.deepcopy(messages)
        for message in processed_messages:
            # Some messages may not have content.
            if not isinstance(message.get("content"), (str, list)):
                continue

            if _is_content_text_empty(message["content"]):
                continue

            cached_content = self._cache_get(message["content"])
            if cached_content is not None:
                savings, compressed_content = cached_content
            else:
                savings, compressed_content = self._compress(message["content"])
                self._cache_set(message["content"], compressed_content, savings)

            message["content"] = compressed_content
            total_savings += savings

        self._recent_tokens_savings = total_savings
        return processed_messages

    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
        """Reports token savings from the most recent apply_transform call."""
        if self._recent_tokens_savings > 0:
            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
        else:
            return "No tokens saved with text compression.", False

    def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
        """Compresses the given text or multimodal content using the specified compression method."""
        if isinstance(content, str):
            return self._compress_text(content)
        elif isinstance(content, list):
            return self._compress_multimodal(content)
        else:
            return 0, content

    def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
        """Compresses the 'text' entries of multimodal content items in place."""
        tokens_saved = 0
        for msg in content:
            if "text" in msg:
                savings, msg["text"] = self._compress_text(msg["text"])
                tokens_saved += savings
        return tokens_saved, content

    def _compress_text(self, text: str) -> Tuple[int, str]:
        """Compresses the given text using the specified compression method."""
        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

        savings = 0
        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

        return savings, compressed_text["compressed_prompt"]

    def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
        """Returns the cached (savings, compressed_content) pair, or None on a cache miss."""
        if self._cache:
            cached_value = self._cache.get(self._cache_key(content))
            if cached_value:
                savings, serialized_content = cached_value
                # _cache_set stores the compressed content JSON-serialized; deserialize it so a
                # cache hit yields the same type (str or list) as a fresh compression would.
                return savings, json.loads(serialized_content)
        return None

    def _cache_set(
        self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
    ):
        """Stores the compression result; content is JSON-serialized for the cache backend."""
        if self._cache:
            value = (tokens_saved, json.dumps(compressed_content))
            self._cache.set(self._cache_key(content), value)

    def _cache_key(self, content: Union[str, List[Dict]]) -> str:
        """Key includes min_tokens so different thresholds do not share cache entries."""
        return f"{json.dumps(content)}_{self._min_tokens}"

    def _validate_min_tokens(self, min_tokens: Optional[int]):
        """Raises ValueError unless min_tokens is None or strictly positive."""
        if min_tokens is not None and min_tokens <= 0:
            raise ValueError("min_tokens must be greater than 0 or None")
409+
410+
411+
def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
412+
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
413+
if not min_tokens:
414+
return True
415+
416+
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
417+
return messages_tokens >= min_tokens
418+
419+
281420
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
282421
token_count = 0
283422
if isinstance(content, str):
@@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
286425
for item in content:
287426
token_count += _count_tokens(item.get("text", ""))
288427
return token_count
428+
429+
430+
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
431+
if isinstance(content, str):
432+
return content == ""
433+
elif isinstance(content, list):
434+
return all(_is_content_text_empty(item.get("text", "")) for item in content)
435+
else:
436+
return False

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
"websockets": ["websockets>=12.0,<13"],
8080
"jupyter-executor": jupyter_executor,
8181
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
82+
"long-context": ["llmlingua<0.3"],
8283
}
8384

8485
setuptools.setup(

0 commit comments

Comments
 (0)