Skip to content

Text Compression Transform #2225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 43 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
472d0f8
adds implementation
WaelKarkoub Mar 31, 2024
28fa3f5
handles optional import
WaelKarkoub Mar 31, 2024
12f95f2
cleanup
WaelKarkoub Mar 31, 2024
50607eb
updates github workflows
WaelKarkoub Mar 31, 2024
de3e655
skip test if dependencies not installed
WaelKarkoub Mar 31, 2024
7ee0a55
skip test if dependencies not installed
WaelKarkoub Mar 31, 2024
aa572e3
use cpu
WaelKarkoub Mar 31, 2024
62cfd33
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Mar 31, 2024
36388a9
skip openai
WaelKarkoub Mar 31, 2024
a182c34
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Mar 31, 2024
51a3085
unskip openai
WaelKarkoub Apr 1, 2024
71d0447
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Apr 2, 2024
b3eb1b3
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Apr 4, 2024
32eb666
adds protocol
WaelKarkoub Apr 4, 2024
5a9d672
better docstr
WaelKarkoub Apr 4, 2024
93bc0eb
minor fixes
WaelKarkoub Apr 4, 2024
b1c1f37
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Apr 5, 2024
268f3e8
updates optional dependencies docs
WaelKarkoub Apr 5, 2024
414b518
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Apr 5, 2024
142ba37
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Apr 7, 2024
6ba0526
wip
WaelKarkoub Apr 7, 2024
2cb5761
update docstrings
WaelKarkoub Apr 7, 2024
73b8092
adds compression test
WaelKarkoub Apr 22, 2024
3584aea
wip
WaelKarkoub Apr 23, 2024
d1e55c4
Merge branch 'main' into llm-lingua-transform
WaelKarkoub Apr 29, 2024
d3439dc
adds back llmlingua requirement
WaelKarkoub Apr 29, 2024
965c459
Merge branch 'main' into llm-lingua-transform
WaelKarkoub May 1, 2024
db73843
finalized protocol
WaelKarkoub May 1, 2024
2b47222
improve docstr
WaelKarkoub May 1, 2024
22b34b1
guide complete
WaelKarkoub May 1, 2024
4432e61
improve docstr
WaelKarkoub May 1, 2024
a28ac00
Merge branch 'main' into llm-lingua-transform
WaelKarkoub May 2, 2024
1fd533e
Merge branch 'main' into llm-lingua-transform
WaelKarkoub May 3, 2024
62afdfb
fix FAQ
WaelKarkoub May 3, 2024
fa94e39
added cache support
WaelKarkoub May 4, 2024
9cf65ed
improve cache key
WaelKarkoub May 4, 2024
8390dec
Merge branch 'main' into llm-lingua-transform
WaelKarkoub May 4, 2024
a3fb568
cache key fix + faq fix
WaelKarkoub May 4, 2024
d46fccf
improve docs
WaelKarkoub May 5, 2024
1cab419
Merge branch 'main' into llm-lingua-transform
WaelKarkoub May 5, 2024
8a9fc96
improve guide
WaelKarkoub May 5, 2024
ec6fe57
args -> params
WaelKarkoub May 5, 2024
98cb736
spelling
WaelKarkoub May 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ jobs:
pip install pytest-cov>=5
- name: Install packages and dependencies for Transform Messages
run: |
pip install -e .
pip install -e '.[long-context]'
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
Expand Down
68 changes: 68 additions & 0 deletions autogen/agentchat/contrib/capabilities/text_compressors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, Protocol

# llmlingua is an optional dependency (installed via the `long-context` extra).
# Instead of failing at module import time, record the ImportError here and raise
# it lazily when a compressor that actually needs llmlingua is constructed.
IMPORT_ERROR: Optional[Exception] = None
try:
    import llmlingua
except ImportError:
    IMPORT_ERROR = ImportError(
        "LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
    )
    # Placeholder so class definitions referencing PromptCompressor still parse
    # when llmlingua is absent; never used because IMPORT_ERROR is raised first.
    PromptCompressor = object
else:
    from llmlingua import PromptCompressor


class TextCompressor(Protocol):
    """Defines a protocol for text compression to optimize agent interactions."""

    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
        """This method takes a string as input and returns a dictionary containing the compressed text and other
        relevant information. The compressed text should be stored under the 'compressed_prompt' key in the
        dictionary, as that is the key consumers of this protocol read the result from.
        To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and
        'compressed_tokens' keys.
        """
        ...


class LLMLingua:
    """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.

    NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the
    messages and the specific configurations used for the PromptCompressor.
    """

    def __init__(
        self,
        prompt_compressor_kwargs: Optional[Dict] = None,
        structured_compression: bool = False,
    ) -> None:
        """
        Args:
            prompt_compressor_kwargs (dict or None): A dictionary of keyword arguments for the PromptCompressor.
                Defaults to None, in which case the compressor is built with model_name set to
                "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank", use_llmlingua2 set to True,
                and device_map set to "cpu".
            structured_compression (bool): A flag indicating whether to use structured compression. If True, the
                structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt
                method is used. Defaults to False.

        Raises:
            ImportError: If the llmlingua library is not installed.
            TypeError: If the constructed compressor is not an llmlingua.PromptCompressor.
        """
        if IMPORT_ERROR:
            raise IMPORT_ERROR

        # Build the default kwargs per call instead of using a mutable default argument,
        # so one caller's mutations can never leak into another instance's configuration.
        if prompt_compressor_kwargs is None:
            prompt_compressor_kwargs = dict(
                model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
                use_llmlingua2=True,
                device_map="cpu",
            )

        self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not isinstance(self._prompt_compressor, llmlingua.PromptCompressor):
            raise TypeError("Creating the prompt compressor did not yield an llmlingua.PromptCompressor instance.")

        # Select the compression entry point once so compress_text stays a thin dispatch.
        self._compression_method = (
            self._prompt_compressor.structured_compress_prompt
            if structured_compression
            else self._prompt_compressor.compress_prompt
        )

    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
        """Compresses `text` via llmlingua and returns its result dictionary (which includes
        the 'compressed_prompt', 'origin_tokens', and 'compressed_tokens' keys)."""
        return self._compression_method([text], **compression_params)
178 changes: 163 additions & 15 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache

from .text_compressors import LLMLingua, TextCompressor


class MessageTransform(Protocol):
Expand Down Expand Up @@ -156,7 +160,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._min_tokens is not None

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not self._are_min_tokens_reached(messages):
if not _min_tokens_reached(messages, self._min_tokens):
return messages

temp_messages = copy.deepcopy(messages)
Expand Down Expand Up @@ -205,19 +209,6 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
return logs_str, True
return "No tokens were truncated.", False

def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
"""
Returns True if no minimum tokens restrictions are applied.

Either if the total number of tokens in the messages is greater than or equal to the `min_theshold_tokens`,
or no minimum tokens threshold is set.
"""
if not self._min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= self._min_tokens

def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
Expand Down Expand Up @@ -268,7 +259,7 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int

return max_tokens if max_tokens is not None else sys.maxsize

def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
if min_tokens is None:
return 0
if min_tokens < 0:
Expand All @@ -278,6 +269,154 @@ def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
return min_tokens


class TextMessageCompressor:
    """A transform for compressing text messages in a conversation history.

    It uses a specified text compression method to reduce the token count of messages, which can lead to more
    efficient processing and response generation by downstream models.
    """

    # Sentinel marking "caller did not pass a cache". This lets the default disk cache be created
    # lazily per instance instead of once at class-definition time (a `Cache.disk()` signature
    # default would be evaluated on import and shared by every instance, even unused ones),
    # while still allowing an explicit `cache=None` to mean "no caching".
    _DEFAULT_CACHE = object()

    def __init__(
        self,
        text_compressor: Optional[TextCompressor] = None,
        min_tokens: Optional[int] = None,
        compression_params: Optional[Dict] = None,
        cache: Optional[AbstractCache] = _DEFAULT_CACHE,
    ):
        """
        Args:
            text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
                protocol. If None, it defaults to LLMLingua.
            min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be
                greater than 0 if not None. If None, no threshold-based compression is applied.
            compression_params (dict or None): A dictionary of arguments for the compression method. Defaults to
                an empty dictionary.
            cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed
                messages. Defaults to a disk cache; if None, no caching will be used.

        Raises:
            ValueError: If `min_tokens` is not None and not greater than 0.
        """
        if text_compressor is None:
            text_compressor = LLMLingua()

        self._validate_min_tokens(min_tokens)

        self._text_compressor = text_compressor
        self._min_tokens = min_tokens
        # Fresh dict per instance: a `dict()` default in the signature would be shared mutable state.
        self._compression_args = compression_params if compression_params is not None else {}
        self._cache = Cache.disk() if cache is TextMessageCompressor._DEFAULT_CACHE else cache

        # Savings from the most recent apply_transform call, cached so get_logs
        # does not have to recount tokens.
        self._recent_tokens_savings = 0

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies compression to messages in a conversation history based on the specified configuration.

        The function processes each message according to the `compression_params` and `min_tokens` settings,
        applying the specified compression configuration and returning a new list of messages with reduced
        token counts where possible.

        Args:
            messages (List[Dict]): A list of message dictionaries to be compressed.

        Returns:
            List[Dict]: A list of dictionaries with the message content compressed according to the configured
                method and scope.
        """
        # Make sure there is at least one message
        if not messages:
            return messages

        # If the total number of tokens in the messages is less than min_tokens, return the messages as is.
        if not _min_tokens_reached(messages, self._min_tokens):
            return messages

        total_savings = 0
        # Deep copy so the caller's message dicts are never mutated; a shallow `.copy()` would share
        # the inner dicts and write the compressed content back into the input list.
        processed_messages = copy.deepcopy(messages)
        for message in processed_messages:
            # Some messages may not have content.
            if not isinstance(message.get("content"), (str, list)):
                continue

            if _is_content_text_empty(message["content"]):
                continue

            cached_content = self._cache_get(message["content"])
            if cached_content is not None:
                savings, compressed_content = cached_content
            else:
                savings, compressed_content = self._compress(message["content"])
                # Only write to the cache on a miss; re-storing an existing entry is redundant work.
                self._cache_set(message["content"], compressed_content, savings)

            message["content"] = compressed_content
            total_savings += savings

        self._recent_tokens_savings = total_savings
        return processed_messages

    def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
        """Returns a (log message, had-effect) pair describing the most recent apply_transform call."""
        if self._recent_tokens_savings > 0:
            return f"{self._recent_tokens_savings} tokens saved with text compression.", True
        else:
            return "No tokens saved with text compression.", False

    def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
        """Compresses the given text or multimodal content using the specified compression method."""
        if isinstance(content, str):
            return self._compress_text(content)
        elif isinstance(content, list):
            return self._compress_multimodal(content)
        else:
            return 0, content

    def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
        """Compresses every 'text' entry of a multimodal content list in place, summing the savings."""
        tokens_saved = 0
        for msg in content:
            if "text" in msg:
                savings, msg["text"] = self._compress_text(msg["text"])
                tokens_saved += savings
        return tokens_saved, content

    def _compress_text(self, text: str) -> Tuple[int, str]:
        """Compresses the given text using the specified compression method."""
        compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

        savings = 0
        if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
            savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

        return savings, compressed_text["compressed_prompt"]

    def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
        """Returns the (tokens_saved, compressed_content) pair cached for `content`, or None on a miss."""
        if self._cache:
            cached_value = self._cache.get(self._cache_key(content))
            if cached_value:
                savings, serialized_content = cached_value
                # _cache_set stores the compressed content JSON-serialized; deserialize here so a
                # cache hit yields the same type (str or list) as a fresh compression would.
                return savings, json.loads(serialized_content)
        return None

    def _cache_set(
        self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
    ):
        """Stores the compression result for `content`, JSON-serializing the compressed payload."""
        if self._cache:
            value = (tokens_saved, json.dumps(compressed_content))
            self._cache.set(self._cache_key(content), value)

    def _cache_key(self, content: Union[str, List[Dict]]) -> str:
        # min_tokens participates in the key because it changes whether/how content gets compressed.
        return f"{json.dumps(content)}_{self._min_tokens}"

    def _validate_min_tokens(self, min_tokens: Optional[int]):
        """Raises ValueError unless min_tokens is None or strictly positive."""
        if min_tokens is not None and min_tokens <= 0:
            raise ValueError("min_tokens must be greater than 0 or None")


def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
Expand All @@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"websockets": ["websockets>=12.0,<13"],
"jupyter-executor": jupyter_executor,
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
}

setuptools.setup(
Expand Down
Loading
Loading