Skip to content

Commit

Permalink
Text Compression Transform (microsoft#2225)
Browse files Browse the repository at this point in the history
* adds implementation

* handles optional import

* cleanup

* updates github workflows

* skip test if dependencies not installed

* skip test if dependencies not installed

* use cpu

* skip openai

* unskip openai

* adds protocol

* better docstr

* minor fixes

* updates optional dependencies docs

* wip

* update docstrings

* wip

* adds back llmlingua requirement

* finalized protocol

* improve docstr

* guide complete

* improve docstr

* fix FAQ

* added cache support

* improve cache key

* cache key fix + faq fix

* improve docs

* improve guide

* args -> params

* spelling
  • Loading branch information
WaelKarkoub authored May 6, 2024
1 parent 25bd6b2 commit 69bce13
Show file tree
Hide file tree
Showing 10 changed files with 503 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ jobs:
pip install pytest-cov>=5
- name: Install packages and dependencies for Transform Messages
run: |
pip install -e .
pip install -e '.[long-context]'
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
Expand Down
68 changes: 68 additions & 0 deletions autogen/agentchat/contrib/capabilities/text_compressors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional, Protocol

IMPORT_ERROR: Optional[Exception] = None
try:
import llmlingua
except ImportError:
IMPORT_ERROR = ImportError(
"LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
)
PromptCompressor = object
else:
from llmlingua import PromptCompressor


class TextCompressor(Protocol):
"""Defines a protocol for text compression to optimize agent interactions."""

def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
"""This method takes a string as input and returns a dictionary containing the compressed text and other
relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
"""
...


class LLMLingua:
"""Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
and the specific configurations used for the PromptCompressor.
"""

def __init__(
self,
prompt_compressor_kwargs: Dict = dict(
model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2=True,
device_map="cpu",
),
structured_compression: bool = False,
) -> None:
"""
Args:
prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
use_llmlingua2 set to True, and device_map set to "cpu".
structured_compression (bool): A flag indicating whether to use structured compression. If True, the
structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
is used. Defaults to False.
dictionary.
Raises:
ImportError: If the llmlingua library is not installed.
"""
if IMPORT_ERROR:
raise IMPORT_ERROR

self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)

assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
self._compression_method = (
self._prompt_compressor.structured_compress_prompt
if structured_compression
else self._prompt_compressor.compress_prompt
)

def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
return self._compression_method([text], **compression_params)
178 changes: 163 additions & 15 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache

from .text_compressors import LLMLingua, TextCompressor


class MessageTransform(Protocol):
Expand Down Expand Up @@ -156,7 +160,7 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
assert self._min_tokens is not None

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not self._are_min_tokens_reached(messages):
if not _min_tokens_reached(messages, self._min_tokens):
return messages

temp_messages = copy.deepcopy(messages)
Expand Down Expand Up @@ -205,19 +209,6 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages:
return logs_str, True
return "No tokens were truncated.", False

def _are_min_tokens_reached(self, messages: List[Dict]) -> bool:
"""
Returns True if no minimum tokens restrictions are applied.
Either if the total number of tokens in the messages is greater than or equal to the `min_theshold_tokens`,
or no minimum tokens threshold is set.
"""
if not self._min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= self._min_tokens

def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
if isinstance(contents, str):
return self._truncate_tokens(contents, n_tokens)
Expand Down Expand Up @@ -268,7 +259,7 @@ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int

return max_tokens if max_tokens is not None else sys.maxsize

def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
if min_tokens is None:
return 0
if min_tokens < 0:
Expand All @@ -278,6 +269,154 @@ def _validate_min_tokens(self, min_tokens: int, max_tokens: int) -> int:
return min_tokens


class TextMessageCompressor:
"""A transform for compressing text messages in a conversation history.
It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
processing and response generation by downstream models.
"""

def __init__(
self,
text_compressor: Optional[TextCompressor] = None,
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
):
"""
Args:
text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
protocol. If None, it defaults to LLMLingua.
min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
than or equal to 0 if not None. If None, no threshold-based compression is applied.
compression_args (dict): A dictionary of arguments for the compression method. Defaults to an empty
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
"""

if text_compressor is None:
text_compressor = LLMLingua()

self._validate_min_tokens(min_tokens)

self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._cache = cache

# Optimizing savings calculations to optimize log generation
self._recent_tokens_savings = 0

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies compression to messages in a conversation history based on the specified configuration.
The function processes each message according to the `compression_args` and `min_tokens` settings, applying
the specified compression configuration and returning a new list of messages with reduced token counts
where possible.
Args:
messages (List[Dict]): A list of message dictionaries to be compressed.
Returns:
List[Dict]: A list of dictionaries with the message content compressed according to the configured
method and scope.
"""
# Make sure there is at least one message
if not messages:
return messages

# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
return messages

total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
continue

if _is_content_text_empty(message["content"]):
continue

cached_content = self._cache_get(message["content"])
if cached_content is not None:
savings, compressed_content = cached_content
else:
savings, compressed_content = self._compress(message["content"])

self._cache_set(message["content"], compressed_content, savings)

message["content"] = compressed_content
total_savings += savings

self._recent_tokens_savings = total_savings
return processed_messages

def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
if self._recent_tokens_savings > 0:
return f"{self._recent_tokens_savings} tokens saved with text compression.", True
else:
return "No tokens saved with text compression.", False

def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content

def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
tokens_saved += savings
return tokens_saved, content

def _compress_text(self, text: str) -> Tuple[int, str]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)

savings = 0
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]

return savings, compressed_text["compressed_prompt"]

def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
if self._cache:
cached_value = self._cache.get(self._cache_key(content))
if cached_value:
return cached_value

def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"

def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")


def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True

messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
Expand All @@ -286,3 +425,12 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"websockets": ["websockets>=12.0,<13"],
"jupyter-executor": jupyter_executor,
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
}

setuptools.setup(
Expand Down
Loading

0 comments on commit 69bce13

Please sign in to comment.