diff --git a/reme_ai/__init__.py b/reme_ai/__init__.py index 62450870..ffbb4ec8 100644 --- a/reme_ai/__init__.py +++ b/reme_ai/__init__.py @@ -8,6 +8,7 @@ from . import agent # noqa: E402 from . import config # noqa: E402 from . import constants # noqa: E402 +from . import context # noqa: E402 from . import enumeration # noqa: E402 from . import retrieve # noqa: E402 from . import schema # noqa: E402 @@ -21,6 +22,7 @@ "agent", "config", "constants", + "context", "enumeration", "retrieve", "schema", diff --git a/reme_ai/config/default.yaml b/reme_ai/config/default.yaml index 8e8c233f..4c8fdc81 100644 --- a/reme_ai/config/default.yaml +++ b/reme_ai/config/default.yaml @@ -156,6 +156,83 @@ flow: description: "user query" required: true + context_offload: + flow_content: ContextOffloadOp() >> BatchWriteFileOp() + description: "Manages context window limits by compacting tool messages and compressing conversation history. First compacts large tool messages by storing full content in external files, then applies LLM-based compression if compaction ratio exceeds threshold. This helps reduce token usage while preserving important information." + input_schema: + messages: + type: array + description: "List of conversation messages to process for context offloading" + required: true + context_manage_mode: + type: string + description: "Context management mode: 'compact' only applies compaction to tool messages, 'compress' only applies LLM-based compression, 'auto' applies compaction first then compression if compaction ratio exceeds threshold. Defaults to 'auto'." + required: false + enum: ["compact", "compress", "auto"] + max_total_tokens: + type: integer + description: "Maximum token count threshold for triggering compression/compaction. For compaction, this is the total token count threshold. For compression, this excludes keep_recent_count messages and system messages. Defaults to 20000." 
+ required: false + max_tool_message_tokens: + type: integer + description: "Maximum token count per tool message before compaction is applied. Tool messages exceeding this threshold will have their full content stored in external files with only a preview kept in context. Defaults to 2000." + required: false + group_token_threshold: + type: integer + description: "Maximum token count per compression group when using LLM-based compression. If None or 0, all messages are compressed in a single group. Messages exceeding this threshold individually will form their own group. Only used in 'compress' or 'auto' mode." + required: false + keep_recent_count: + type: integer + description: "Number of recent messages to preserve without compression or compaction. These messages remain unchanged to maintain conversation context. Defaults to 1 for compaction and 2 for compression." + required: false + store_dir: + type: string + description: "Directory path for storing offloaded message content. Full tool message content and compressed message groups are saved as files in this directory. Required for compaction and compression operations." + required: false + chat_id: + type: string + description: "Unique identifier for the chat session, used for file naming when storing compressed message groups. If not provided, a UUID will be generated automatically." + required: false + + context_offload_for_agentscope: + flow_content: ContextOffloadOp() + description: "Context offload operation for AgentScope integration. Manages context window limits by compacting tool messages and compressing conversation history without batch file writing. Same functionality as context_offload but without the BatchWriteFileOp step." 
+ input_schema: + messages: + type: array + description: "List of conversation messages to process for context offloading" + required: true + context_manage_mode: + type: string + description: "Context management mode: 'compact' only applies compaction to tool messages, 'compress' only applies LLM-based compression, 'auto' applies compaction first then compression if compaction ratio exceeds threshold. Defaults to 'auto'." + required: false + enum: ["compact", "compress", "auto"] + max_total_tokens: + type: integer + description: "Maximum token count threshold for triggering compression/compaction. For compaction, this is the total token count threshold. For compression, this excludes keep_recent_count messages and system messages. Defaults to 20000." + required: false + max_tool_message_tokens: + type: integer + description: "Maximum token count per tool message before compaction is applied. Tool messages exceeding this threshold will have their full content stored in external files with only a preview kept in context. Defaults to 2000." + required: false + group_token_threshold: + type: integer + description: "Maximum token count per compression group when using LLM-based compression. If None or 0, all messages are compressed in a single group. Messages exceeding this threshold individually will form their own group. Only used in 'compress' or 'auto' mode." + required: false + keep_recent_count: + type: integer + description: "Number of recent messages to preserve without compression or compaction. These messages remain unchanged to maintain conversation context. Defaults to 1 for compaction and 2 for compression." + required: false + store_dir: + type: string + description: "Directory path for storing offloaded message content. Full tool message content and compressed message groups are saved as files in this directory. Required for compaction and compression operations." 
+ required: false + chat_id: + type: string + description: "Unique identifier for the chat session, used for file naming when storing compressed message groups. If not provided, a UUID will be generated automatically." + required: false + + llm: default: backend: openai_compatible @@ -164,9 +241,7 @@ llm: temperature: 0.6 token_count: # Optional model_name: Qwen/Qwen3-30B-A3B-Instruct-2507 - backend: hf - params: - use_mirror: true + backend: base qwen3_30b_instruct: backend: openai_compatible diff --git a/reme_ai/context/__init__.py b/reme_ai/context/__init__.py index e69de29b..aa96c17a 100644 --- a/reme_ai/context/__init__.py +++ b/reme_ai/context/__init__.py @@ -0,0 +1,14 @@ +"""Context management module for ReMe framework. + +This module provides submodules for different types of context management operations: +- file_tool: File-related operations for reading, writing, and searching files +- offload: Context offload operations for reducing token usage and managing context windows +""" + +from . import file_tool +from . import offload + +__all__ = [ + "file_tool", + "offload", +] diff --git a/reme_ai/context/offload/__init__.py b/reme_ai/context/offload/__init__.py index e69de29b..626fe0fd 100644 --- a/reme_ai/context/offload/__init__.py +++ b/reme_ai/context/offload/__init__.py @@ -0,0 +1,19 @@ +"""Context offload package for ReMe framework. + +This package provides context management operations that can be used in LLM-powered flows +to reduce token usage and manage context window limits. 
It includes ready-to-use operations for: + +- ContextCompactOp: Compact tool messages by storing full content in external files +- ContextCompressOp: Compress conversation history using LLM to generate concise summaries +- ContextOffloadOp: Orchestrate compaction and compression to reduce token usage +""" + +from .context_compact_op import ContextCompactOp +from .context_compress_op import ContextCompressOp +from .context_offload_op import ContextOffloadOp + +__all__ = [ + "ContextCompactOp", + "ContextCompressOp", + "ContextOffloadOp", +] diff --git a/reme_ai/context/offload/context_compact_op.py b/reme_ai/context/offload/context_compact_op.py index 9f5a81ef..7bd878a1 100644 --- a/reme_ai/context/offload/context_compact_op.py +++ b/reme_ai/context/offload/context_compact_op.py @@ -6,9 +6,7 @@ This helps manage context window limits while preserving important information. """ -import json from pathlib import Path -from typing import List from uuid import uuid4 from flowllm.core.context import C @@ -28,36 +26,6 @@ class ContextCompactOp(BaseAsyncOp): This helps manage context window limits while preserving recent tool messages. """ - def __init__( - self, - all_token_threshold: int = 20000, - tool_token_threshold: int = 2000, - tool_left_char_len: int = 100, - keep_recent: int = 1, - storage_path: str = "./", - exclude_tools: List[str] = None, - **kwargs, - ): - """ - Initialize the context compaction operation. - - Args: - all_token_threshold: Maximum total token count before compaction is triggered. - tool_token_threshold: Maximum token count for a single tool message before it's compressed. - tool_left_char_len: Number of characters to keep in the compressed tool message preview. - keep_recent: Number of recent tool messages to keep uncompressed. - storage_path: Directory path where compressed tool message contents will be stored. - exclude_tools: List of tool names to exclude from compaction (not currently used). 
- **kwargs: Additional arguments passed to the base class. - """ - super().__init__(**kwargs) - self.all_token_threshold: int = all_token_threshold - self.tool_token_threshold: int = tool_token_threshold - self.tool_left_char_len: int = tool_left_char_len - self.keep_recent: int = keep_recent - self.storage_path: Path = Path(storage_path) - self.exclude_tools: List[str] = exclude_tools - async def async_execute(self): """ Execute the context compaction operation. @@ -70,41 +38,44 @@ async def async_execute(self): - Storing full content in external files - Preserving recent tool messages """ + # Get configuration from context + max_total_tokens: int = self.context.get("max_total_tokens", 20000) + max_tool_message_tokens: int = self.context.get("max_tool_message_tokens", 2000) + preview_char_length: int = self.context.get("preview_char_length", 100) + keep_recent_count: int = self.context.get("keep_recent_count", 1) + store_dir: Path = Path(self.context.get("store_dir", "")) + + assert max_total_tokens > 0, "max_total_tokens must be greater than 0" + assert max_tool_message_tokens > 0, "max_tool_message_tokens must be greater than 0" + assert preview_char_length >= 0, "preview_char_length must be greater than 0" + assert keep_recent_count > 0, "keep_recent_count must be greater than 0" + # Convert context messages to Message objects messages = [Message(**x) for x in self.context.messages] + messages_to_compress = [x for x in messages if x.role is not Role.SYSTEM][:-keep_recent_count] - # Calculate total token count - token_cnt: int = self.token_count(messages) - logger.info(f"Context compaction check: total token count={token_cnt}, threshold={self.all_token_threshold}") - - # If token count is within threshold, no compaction needed - if token_cnt <= self.all_token_threshold: + # If nothing to compress after filtering, return original messages + if not messages_to_compress: self.context.response.answer = self.context.messages - logger.info( - f"Token count 
({token_cnt}) is within threshold ({self.all_token_threshold}), no compaction needed", - ) + logger.info("No messages to compress after filtering, returning original messages") return - # Filter tool messages for processing - tool_messages = [x for x in messages if x.role is Role.TOOL] + logger.info(f"{len(messages_to_compress)} messages remaining for compression check") + + # Calculate total token count + compact_token_cnt: int = self.token_count(messages_to_compress) + logger.info(f"Context compaction check: total token count={compact_token_cnt}, threshold={max_total_tokens}") - # If there are too few tool messages, no compaction needed - if len(tool_messages) <= self.keep_recent: + # If token count is within threshold, no compaction needed + if compact_token_cnt <= max_total_tokens: self.context.response.answer = self.context.messages - logger.info( - f"Tool message count ({len(tool_messages)}) is less than or " - f"equal to keep_recent ({self.keep_recent}), no compaction needed", - ) + logger.info(f"Token count ({compact_token_cnt}) is within ({max_total_tokens}), no compaction needed") return - # Exclude recent tool messages from compaction (keep them intact) - tool_messages = tool_messages[: -self.keep_recent] - logger.info( - f"Processing {len(tool_messages)} tool messages for " - f"compaction (keeping {self.keep_recent} recent messages)", - ) + # Filter tool messages for processing + tool_messages = [x for x in messages_to_compress if x.role is Role.TOOL] - # Dictionary to store file paths and their compressed content (for potential batch writing) + # Dictionary to store file paths and their full content (for potential batch writing) write_file_dict = {} # Process each tool message @@ -113,33 +84,56 @@ async def async_execute(self): tool_token_cnt = self.token_count([tool_message]) # Skip if token count is within threshold - if tool_token_cnt <= self.tool_token_threshold: + if tool_token_cnt <= max_tool_message_tokens: logger.info( f"Skipping tool message 
(tool_call_id={tool_message.tool_call_id}): " - f"token count ({tool_token_cnt}) is within threshold ({self.tool_token_threshold})", + f"token count ({tool_token_cnt}) is within threshold ({max_tool_message_tokens})", ) continue - # Create compressed preview of the tool message content - compact_result = tool_message.content[: self.tool_left_char_len] + "..." + # Save original full content before modifying + original_content = tool_message.content # Generate file name from tool_call_id or create a unique identifier file_name = tool_message.tool_call_id or uuid4().hex - path = self.storage_path / f"{file_name}.txt" + store_path = store_dir / f"{file_name}.txt" + + # Store the full content for batch writing + write_file_dict[store_path.as_posix()] = original_content - # Store the mapping for potential batch writing - write_file_dict[str(path)] = compact_result + # Create compressed preview of the tool message content + compact_result = original_content[:preview_char_length] + "..." # Log the compaction action logger.info( f"Compacting tool message (tool_call_id={tool_message.tool_call_id}): " - f"token count={tool_token_cnt}, saving full content to {path}", + f"token count={tool_token_cnt}, saving full content to {store_path}", ) # Update tool message content with preview and file reference - compact_result += f" (detailed result is stored in {path})" + compact_result += f" (detailed result is stored in {store_path})" tool_message.content = compact_result + # Store write_file_dict in context for potential batch writing + if write_file_dict: + self.context.write_file_dict = write_file_dict + # Return the compacted messages as JSON - self.context.response.answer = json.dumps([x.simple_dump() for x in messages], ensure_ascii=False, indent=2) + self.context.response.answer = [x.simple_dump() for x in messages] + self.context.response.metadata["write_file_dict"] = write_file_dict + logger.info(f"Context compaction completed: {len(write_file_dict)} tool messages were 
compacted") + + async def async_default_execute(self, e: Exception = None, **_kwargs): + """Handle execution errors by returning original messages. + + This method is called when an exception occurs during async_execute. It preserves + the original messages and marks the operation as unsuccessful. + + Args: + e: The exception that occurred during execution, if any. + **_kwargs: Additional keyword arguments (unused but required by interface). + """ + self.context.response.answer = self.context.messages + self.context.response.success = False + self.context.response.metadata["error"] = str(e) diff --git a/reme_ai/context/offload/context_compress_op.py b/reme_ai/context/offload/context_compress_op.py index 7cc2c141..218f6e21 100644 --- a/reme_ai/context/offload/context_compress_op.py +++ b/reme_ai/context/offload/context_compress_op.py @@ -4,22 +4,29 @@ This module provides functionality to compress conversation history by using a language model to generate concise summaries of older messages while preserving recent messages. This helps manage context window limits while maintaining conversation coherence. + +The compression process: +1. Identifies messages that exceed token thresholds +2. Splits messages into groups if needed +3. Uses LLM to generate compressed summaries of older message groups +4. Stores original messages to files for potential retrieval +5. 
Appends compressed summaries to the system message while preserving recent messages """ import json -import re -import xml.etree.ElementTree as ET from pathlib import Path -from typing import List +from typing import List, Tuple from uuid import uuid4 from flowllm.core.context import C from flowllm.core.enumeration import Role from flowllm.core.op import BaseAsyncOp from flowllm.core.schema import Message -from flowllm.core.utils import extract_content from loguru import logger +from reme_ai.utils import merge_messages_content +from reme_ai.utils.op_utils import extract_xml_tag_content + @C.register_op() class ContextCompressOp(BaseAsyncOp): @@ -29,90 +36,46 @@ class ContextCompressOp(BaseAsyncOp): When the total token count exceeds the threshold, this operation uses a language model to compress older messages into a concise summary while keeping recent messages intact. This preserves conversation context while reducing token usage. + + Attributes: + file_path: Path to the operation file, used for configuration. + + Context Parameters: + max_total_tokens (int): Maximum token count threshold for compression. + Defaults to 20000. Does not include keep_recent_count messages or system messages. + group_token_threshold (int, optional): Maximum token count per compression group. + If None or 0, all messages are compressed in a single group. + keep_recent_count (int): Number of recent messages to preserve without compression. + Defaults to 2. Must be non-negative. + chat_id (str): Unique identifier for the chat session, used for file naming. + Defaults to a generated UUID if not provided. """ file_path: str = __file__ - def __init__( - self, - all_token_threshold: int = 20000, - keep_recent: int = 5, - storage_path: str = "./compressed_contexts", - micro_summary_token_threshold: int = None, - **kwargs, - ): - """ - Initialize the context compression operation. - - Args: - all_token_threshold: Maximum total token count before compression is triggered. 
- keep_recent: Number of recent messages to keep uncompressed. - storage_path: Directory path where original messages will be stored for traceability. - micro_summary_token_threshold: Token threshold for each compression group. - If set, messages will be split into groups of this size and compressed separately. - If None, all messages will be compressed together. - **kwargs: Additional arguments passed to the base class. - - Note: - System messages are NEVER compressed to preserve important system instructions. - """ - super().__init__(**kwargs) - self.all_token_threshold: int = all_token_threshold - self.keep_recent: int = keep_recent - self.storage_path: Path = Path(storage_path) - self.micro_summary_token_threshold: int = micro_summary_token_threshold - - assert ( - micro_summary_token_threshold is None or micro_summary_token_threshold > 0 - ), "Micro summary token threshold must be greater than 0" - - # Create storage directory if it doesn't exist - self.storage_path.mkdir(parents=True, exist_ok=True) - - def _save_original_messages(self, messages: List[Message]) -> str: - """Save original messages to file for traceability. + def get_store_path(self, name: str) -> Path: + """Get the storage path for a given file name. Args: - messages: List of messages to save + name: Name of the file to store. Returns: - Path to the saved file + Path object representing the full path to the storage location. 
""" - # Generate unique filename with timestamp - file_name = f"context_{uuid4().hex}.txt" - file_path = self.storage_path / file_name - - # Convert messages to serializable format - messages_data = [ - { - "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role), - "content": msg.content, - "name": getattr(msg, "name", None), - "tool_call_id": getattr(msg, "tool_call_id", None), - } - for msg in messages - ] - - # Save to file - with open(file_path, "w", encoding="utf-8") as f: - json.dump(messages_data, f, ensure_ascii=False, indent=2) - - logger.info(f"Saved {len(messages)} original messages to {file_path}") - return str(file_path) - - def _split_messages_by_token_threshold( - self, - messages: List[Message], - token_threshold: int, - ) -> List[List[Message]]: + return Path(self.context.store_dir) / name + + def _split_messages_by_token_threshold(self, messages: List[Message], token_threshold: int) -> List[List[Message]]: """Split messages into groups based on token threshold. + Messages are grouped such that each group's token count does not exceed the threshold, + except when a single message exceeds the threshold, in which case it forms its own group. 
+ Args: messages: List of messages to split - token_threshold: Maximum token count for each group + token_threshold: Maximum token count for each group (may be exceeded by single messages) Returns: - List of message groups, each within the token threshold + List of message groups, where each group attempts to stay within the token threshold """ if not messages: return [] @@ -146,112 +109,43 @@ def _split_messages_by_token_threshold( if current_group: groups.append(current_group) - logger.info( - f"Split {len(messages)} messages into {len(groups)} groups " f"with token threshold {token_threshold}", - ) + logger.info(f"Split {len(messages)} messages into {len(groups)} groups with token threshold {token_threshold}") return groups - @staticmethod - def _extract_xml_fragments(text: str) -> str: - """ - Extract XML fragments from text, removing scratchpad elements. - - Scans text to extract complete and parseable top-level XML fragments, - excluding elements. If state_snapshot XML is found, returns it; - otherwise returns the original text. - - Args: - text: Input text potentially containing XML fragments - - Returns: - Extracted XML content or original text - """ - try: - # Remove scratchpad elements - new_text = re.sub(r".*?", "", text, flags=re.S | re.I) - # Extract balanced XML tags - extract_xml = [m[0] for m in re.findall(r"(<(\w+)[^>]*>(?:[^<]|<(?!/\2))*)", new_text)] - - # Validate XML parsing - valid_xml = [] - for xml_str in extract_xml: - try: - ET.fromstring(xml_str) - valid_xml.append(xml_str) - except ET.ParseError: - continue - - # Return state_snapshot if found, otherwise original text - if valid_xml and any("" in xml for xml in valid_xml): - return next(xml for xml in valid_xml if "" in xml) - elif valid_xml: - return valid_xml[0] - else: - return text - except Exception as e: - logger.warning(f"Failed to extract XML fragments: {e}. 
Returning original text.") - return text + async def _compress_messages_with_llm(self, messages_to_compress: List[Message]) -> str: + """Compress a list of messages using LLM to generate a summary. - @staticmethod - def _format_messages_for_compression(messages: List[Message]) -> str: - """Format messages into a readable text for compression. + This method formats the messages into a prompt, sends it to the LLM, and extracts + the compressed state snapshot from the response. The LLM response is expected to + contain XML tags for scratchpad and state_snapshot. Args: - messages: List of messages to format + messages_to_compress: List of Message objects to compress into a summary. Returns: - Formatted string representation of messages - """ - lines = [] - for i, msg in enumerate(messages, 1): - role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - lines.append(f"[Message {i} - {role_name}]") - lines.append(msg.content) - lines.append("") # Empty line between messages - - return "\n".join(lines) - - async def _compress_messages_with_llm(self, messages_to_compress: List[Message]) -> str: - """Use LLM to compress messages into a concise summary. + Compressed summary string extracted from the LLM response. Returns empty string + if LLM returns None or if state_snapshot cannot be extracted. - Args: - messages_to_compress: List of messages to compress - - Returns: - Compressed summary text + Note: + If state_snapshot extraction fails, the full content is used as fallback. """ - # Format messages for the prompt - formatted_messages = self._format_messages_for_compression(messages_to_compress) - - # Create prompt for compression prompt = self.prompt_format( - prompt_name="compress_context_prompt", - messages_content=formatted_messages, + "compress_context_prompt", + messages_content=merge_messages_content(messages_to_compress), ) def parse_compressed_result(message: Message) -> str: - """Parse LLM response to extract compressed content. 
- - Args: - message: LLM response message - - Returns: - Compressed content string - """ content = message.content.strip() - # Try to extract content from txt code block - compressed = extract_content(content, "txt") - # If no code block found, use the raw content - if not compressed: - compressed = content + scratchpad = extract_xml_tag_content(content, "scratchpad") + state_snapshot = extract_xml_tag_content(content, "state_snapshot") + logger.info(f"Parsed scratchpad: \n{scratchpad} \nstate_snapshot: \n{state_snapshot}") - logger.info( - f"Compressed {len(messages_to_compress)} messages into " - f"{len(compressed)} characters (reduction: " - f"{len(formatted_messages)} -> {len(compressed)})", - ) - return compressed + if state_snapshot is None: + logger.warning("Failed to extract state_snapshot from LLM response, using full content as fallback") + return content + + return state_snapshot # Call LLM to generate compressed summary result = await self.llm.achat( @@ -259,158 +153,81 @@ def parse_compressed_result(message: Message) -> str: callback_fn=parse_compressed_result, ) + if result is None: + logger.error("LLM returned None, using empty string as fallback") + return "" + return result - async def _compress_with_micro_groups( + async def _compress_with_groups( self, - messages_to_compress: List[Message], - system_messages: List[Message], - recent_messages: List[Message], - ) -> List[Message]: - """Compress messages by splitting into groups and compressing each separately. + system_message: Message, + message_groups: List[List[Message]], + ) -> Tuple[dict, list]: + """Compress multiple message groups and prepare them for storage. + + This method processes each message group, compresses it using LLM, and determines + whether compression is beneficial. If compression reduces token count, the original + messages are saved to files and compressed summaries are appended to the system + message. Otherwise, original messages are preserved in the return list. 
Args: - messages_to_compress: Messages to be compressed - system_messages: System messages to preserve - recent_messages: Recent messages to keep uncompressed + system_message: The system message to append compressed summaries to. + message_groups: List of message groups, where each group is a list of Message + objects to be compressed together. Returns: - List of new messages after compression + A tuple containing: + - write_file_dict: Dictionary mapping file paths to JSON-serialized message + strings for messages that were successfully compressed. + - return_messages: List of Message objects including the modified system + message with compressed summaries and any messages that couldn't be + compressed or didn't benefit from compression. """ - # Split messages into groups based on micro threshold - message_groups = self._split_messages_by_token_threshold( - messages_to_compress, - self.micro_summary_token_threshold, - ) + write_file_dict = {} + return_messages = [] + chat_id: str = self.context.get("chat_id", uuid4().hex) - # Compress each group separately - compressed_messages = [] - total_original_tokens = 0 - total_compressed_tokens = 0 + # Create a copy of system_message to avoid modifying the original + system_message_copy = Message(role=system_message.role, content=system_message.content) - for group_idx, group in enumerate(message_groups, 1): - # Calculate original token count for this group - group_original_tokens = self.token_count(group) - total_original_tokens += group_original_tokens + for g_idx, messages in enumerate(message_groups): + group_original_tokens = self.token_count(messages) + messages_str = json.dumps([x.simple_dump() for x in messages], ensure_ascii=False, indent=2) + store_path = Path(self.context.get("store_dir", "")) / f"{chat_id}_{g_idx}.json" - # Save original messages for this group - group_file_path = self._save_original_messages(group) + logger.info(f"Compress {g_idx}/{len(message_groups)} ({len(messages)}, 
{group_original_tokens} tokens)") + group_summary = await self._compress_messages_with_llm(messages) - # Compress this group - logger.info( - f"Compressing group {group_idx}/{len(message_groups)} " - f"({len(group)} messages, {group_original_tokens} tokens)", - ) - group_summary = await self._compress_messages_with_llm(group) - group_summary = self._extract_xml_fragments(group_summary) - - # Create compressed message for this group - compressed_message = Message( - role=Role.SYSTEM, - content=( - f"[Compressed conversation history - Part {group_idx}/{len(message_groups)}]\n" - f"{group_summary}\n\n" - f"(Original {len(group)} messages are stored in: {group_file_path})" - ), - ) + if not group_summary: + logger.warning(f"Group {g_idx} compression returned empty summary, using original messages.") + return_messages.extend(messages) + continue - # Check if compression actually reduced tokens for this group - compressed_tokens = self.token_count([compressed_message]) + compress_content = ( + f"[Compressed conversation history - Part {g_idx}/{len(message_groups)}]\n{group_summary}\n\n" + f"(Original {len(messages)} messages are stored in: {store_path.as_posix()})\n" + ) + compressed_tokens = self.token_count([Message(content=compress_content)]) if compressed_tokens >= group_original_tokens: logger.warning( - f"Group {group_idx} compression did not reduce tokens: " + f"Group {g_idx} compression did not reduce tokens: " f"{group_original_tokens} -> {compressed_tokens}. 
Using original messages.", ) - return None - - logger.info( - f"Group {group_idx} compression successful: " - f"{group_original_tokens} -> {compressed_tokens} tokens " - f"(reduction: {group_original_tokens - compressed_tokens} tokens, " - f"{100 * (1 - compressed_tokens / group_original_tokens):.1f}%)", - ) - compressed_messages.append(compressed_message) - total_compressed_tokens += compressed_tokens - - # Construct new message list: system messages + all compressed messages + recent messages - new_messages = system_messages + compressed_messages + recent_messages - - logger.info( - f"Context compression completed using micro-compression: " - f"{len(messages_to_compress) + len(system_messages) + len(recent_messages)} messages -> " - f"{len(new_messages)} messages ({len(message_groups)} compressed groups), " - f"total tokens: {total_original_tokens} -> {total_compressed_tokens}", - ) - - return new_messages - - async def _compress_all_together( - self, - messages_to_compress: List[Message], - system_messages: List[Message], - recent_messages: List[Message], - ) -> List[Message]: - """Compress all messages together into a single summary. 
- - Args: - messages_to_compress: Messages to be compressed - system_messages: System messages to preserve - recent_messages: Recent messages to keep uncompressed - - Returns: - List of new messages after compression - """ - # Calculate original token count - original_tokens = self.token_count(messages_to_compress) - - # Save original messages to file for traceability - original_file_path = self._save_original_messages(messages_to_compress) - - # Use LLM to compress messages - logger.info( - f"Starting LLM compression of {len(messages_to_compress)} messages " - f"({original_tokens} tokens), keeping {len(recent_messages)} recent messages", - ) - compressed_summary = await self._compress_messages_with_llm(messages_to_compress) - compressed_summary = self._extract_xml_fragments(compressed_summary) - - # Create a new system message with the compressed content and file reference - compressed_message = Message( - role=Role.SYSTEM, - content=( - f"[Compressed conversation history]\n" - f"{compressed_summary}\n\n" - f"(Original {len(messages_to_compress)} messages are stored in: {original_file_path})" - ), - ) - - # Check if compression actually reduced tokens - compressed_tokens = self.token_count([compressed_message]) - - if compressed_tokens >= original_tokens: - logger.warning( - f"Compression did not reduce tokens: {original_tokens} -> {compressed_tokens}. 
" - f"Returning original messages.", - ) - return None - - logger.info( - f"Compression successful: {original_tokens} -> {compressed_tokens} tokens " - f"(reduction: {original_tokens - compressed_tokens} tokens, " - f"{100 * (1 - compressed_tokens / original_tokens):.1f}%)", - ) - - # Construct new message list: system messages + compressed message + recent messages - new_messages = system_messages + [compressed_message] + recent_messages - - logger.info( - f"Context compression completed: " - f"{len(messages_to_compress) + len(system_messages) + len(recent_messages)} messages -> " - f"{len(new_messages)} messages", - ) + return_messages.extend(messages) + else: + system_message_copy.content += compress_content + "\n\n" + write_file_dict[store_path.as_posix()] = messages_str + logger.info( + f"Group {g_idx} compression successful: " + f"{group_original_tokens} -> {compressed_tokens} tokens " + f"(reduction: {group_original_tokens - compressed_tokens} tokens, " + f"{100 * (1 - compressed_tokens / group_original_tokens):.1f}%)", + ) - return new_messages + return_messages = [system_message_copy] + return_messages + return write_file_dict, return_messages async def async_execute(self): """ @@ -423,31 +240,33 @@ async def async_execute(self): 4. 
Otherwise, uses LLM to compress older messages by: - Saving original messages to file - Generating a concise summary of older messages - - Replacing older messages with a single summary message + - Appending compressed summaries to the system message + - Preserving messages that couldn't be compressed or didn't benefit from compression """ + # Get configuration from context + # Note: max_total_tokens does not include keep_recent_count messages or system messages + max_total_tokens: int = self.context.get("max_total_tokens", 20000) + group_token_threshold: int = self.context.get("group_token_threshold", None) + keep_recent_count: int = self.context.get("keep_recent_count", 2) + + assert max_total_tokens > 0, "max_total_tokens must be positive" + assert keep_recent_count >= 0, "keep_recent_count must be non-negative" + # Convert context messages to Message objects messages = [Message(**x) for x in self.context.messages] - # Check if we have enough messages to compress - if len(messages) <= self.keep_recent: - self.context.response.answer = self.context.messages - logger.info( - f"Message count ({len(messages)}) is less than or " - f"equal to keep_recent ({self.keep_recent}), no compression needed", - ) - return + # Extract system message (should be exactly one) + system_message = [x for x in messages if x.role is Role.SYSTEM] + assert len(system_message) <= 1, f"Expected at most one system message, got {len(system_message)}" - # Split messages into those to compress and those to keep - messages_to_compress = messages[: -self.keep_recent] - recent_messages = messages[-self.keep_recent :] + if len(system_message) == 0: + system_message = Message(role=Role.SYSTEM, content="") + else: + system_message = system_message[0] - # Always filter out system messages (system messages are never compressed) - system_messages = [m for m in messages_to_compress if m.role is Role.SYSTEM] - messages_to_compress = [m for m in messages_to_compress if m.role is not Role.SYSTEM] - 
logger.info( - f"Excluding {len(system_messages)} system messages from compression, " - f"{len(messages_to_compress)} messages remaining for compression check", - ) + messages_without_system = [x for x in messages if x.role is not Role.SYSTEM] + messages_to_compress = messages_without_system[:-keep_recent_count] + recent_messages = messages_without_system[-keep_recent_count:] # If nothing to compress after filtering, return original messages if not messages_to_compress: @@ -455,40 +274,42 @@ async def async_execute(self): logger.info("No messages to compress after filtering, returning original messages") return + logger.info(f"{len(messages_to_compress)} messages remaining for compression check") + # Calculate token count of messages to compress (only the content that will be compressed) compress_token_cnt: int = self.token_count(messages_to_compress) - logger.info( - f"Context compression check: messages_to_compress token count={compress_token_cnt}, " - f"threshold={self.all_token_threshold}", - ) + logger.info(f"Context compression check: token count={compress_token_cnt} threshold={max_total_tokens}") # If token count is within threshold, no compression needed - if compress_token_cnt <= self.all_token_threshold: + if compress_token_cnt <= max_total_tokens: self.context.response.answer = self.context.messages - logger.info( - f"Messages to compress token count ({compress_token_cnt}) is within threshold " - f"({self.all_token_threshold}), no compression needed", - ) + logger.info(f"messages_to_compress ({compress_token_cnt}) is within threshold ({max_total_tokens})") return - # Determine whether to use micro-compression (split into groups) or compress all together - if self.micro_summary_token_threshold is not None and self.micro_summary_token_threshold > 0: - new_messages = await self._compress_with_micro_groups( - messages_to_compress, - system_messages, - recent_messages, - ) + if group_token_threshold is not None and group_token_threshold > 0: + message_groups 
async def async_default_execute(self, e: Exception = None, **_kwargs):
    """Fallback handler invoked when async_execute raises an exception.

    Restores the untouched input messages as the response answer, flags the
    response as unsuccessful, and records the error text in the response
    metadata so callers can inspect what went wrong.

    Args:
        e: The exception raised during async_execute, if any.
        **_kwargs: Extra keyword arguments accepted for interface
            compatibility; they are ignored.
    """
    response = self.context.response
    response.answer = self.context.messages
    response.success = False
    response.metadata["error"] = str(e)
- ’‘’ +compress_context_prompt_zh: | + 以下是需要压缩的对话历史: + ''' {messages_content} - ’‘’ - - 首先,您将在一个私有的中考虑整个历史。审查用户的总体目标。请你一定要使用中文来回答和压缩。 - - -compress_context_prompt: | - You are the component that summarizes internal chat history into a given structure. - - When the conversation history grows too large, you will be invoked to distill the entire history into a concise, structured XML snapshot. This snapshot is CRITICAL, as it will become the agent's *only* memory of the past. The agent will resume its work based solely on this snapshot. All crucial details, plans, errors, and user directives MUST be preserved. - - First, you will think through the entire history in a private . Review the user's overall goal, the agent's actions, tool outputs, file modifications, and any unresolved questions. Identify every piece of information that is essential for future actions. - - After your reasoning is complete, generate the final XML object. Be incredibly dense with information. Omit any irrelevant conversational filler. + ''' - The structure MUST be as follows: + 你是负责将内部聊天历史总结为给定结构的组件。 + 当对话历史变得过长时,你将被调用来将整个历史提炼成一个简洁、结构化的 XML 快照。这个快照至关重要,因为它将成为智能体对过去的*唯一*记忆。智能体将仅基于这个快照继续工作。所有关键细节、计划、错误和用户指令都必须被保留。 + 首先,你将在私有的 中思考整个历史。回顾用户的总体目标、智能体的行动、工具输出、文件修改以及任何未解决的问题。识别对未来行动至关重要的每一条信息。 + 完成推理后,生成最终的 XML 对象。信息要极其密集。省略任何无关的对话填充内容。 + 结构必须如下: - - + + - - + - - + - Here is the conversation history that needs to be compressed: - ’’’ - {messages_content} - ’’’ - First, you will think through the entire history in a private . Review the user's overall goal, the agent' \ No newline at end of file + 首先,你将在私有的 中思考整个历史。然后,生成 。 \ No newline at end of file diff --git a/reme_ai/context/offload/context_offload_op.py b/reme_ai/context/offload/context_offload_op.py new file mode 100644 index 00000000..1c086249 --- /dev/null +++ b/reme_ai/context/offload/context_offload_op.py @@ -0,0 +1,107 @@ +""" +Context offload module for managing context window limits through compaction and compression. 
@C.register_op()
class ContextOffloadOp(BaseAsyncOp):
    """
    Context offload operation that orchestrates compaction and compression to reduce token usage.

    This operation combines context compaction and compression strategies to manage context
    window limits. It first applies compaction to tool messages, then evaluates the effectiveness.
    If the compaction ratio (compressed tokens / original tokens) exceeds a threshold, it
    applies additional LLM-based compression to further reduce token count.

    Context Parameters:
        context_manage_mode (ContextManageEnum): The context management mode to use.
            - COMPACT: Only applies context compaction to tool messages.
            - COMPRESS: Only applies LLM-based compression to messages.
            - AUTO: Applies compaction first, then compression if compaction ratio exceeds threshold.
            Defaults to AUTO.
        compact_ratio_threshold (float): Threshold for compaction ratio above which compression
            is applied. Only used in AUTO mode. Defaults to 0.75. If the ratio of compacted
            tokens to original tokens exceeds this value, compression will be triggered.
    """

    async def async_execute(self):
        """
        Execute the context offload operation.

        The operation behavior depends on the context_manage_mode:
        - COMPACT: Only applies context compaction to reduce token usage in tool messages.
        - COMPRESS: Only applies LLM-based compression to generate concise summaries.
        - AUTO: Applies compaction first, then compression if compaction ratio exceeds threshold.

        Raises:
            ValueError: If context_manage_mode is not a recognized mode.
        """
        # Local imports avoid a circular dependency between the offload ops.
        from .context_compact_op import ContextCompactOp
        from .context_compress_op import ContextCompressOp

        # Get the context management mode from context, default to AUTO.
        # Accept plain strings ("compact"/"compress"/"auto") as well as enum values.
        context_manage_mode = self.context.get("context_manage_mode", ContextManageEnum.AUTO)
        if isinstance(context_manage_mode, str):
            context_manage_mode = ContextManageEnum(context_manage_mode)

        context_compact_op = ContextCompactOp()
        context_compress_op = ContextCompressOp()

        if context_manage_mode == ContextManageEnum.COMPACT:
            # Only apply compaction
            logger.info("Context management mode: COMPACT")
            await context_compact_op.async_call(context=self.context)
        elif context_manage_mode == ContextManageEnum.COMPRESS:
            # Only apply compression
            logger.info("Context management mode: COMPRESS")
            await context_compress_op.async_call(context=self.context)
        elif context_manage_mode == ContextManageEnum.AUTO:
            # Apply compaction first, then compression if needed
            logger.info("Context management mode: AUTO")
            await context_compact_op.async_call(context=self.context)

            origin_messages = [Message(**x) for x in self.context.messages]
            origin_token_cnt = self.token_count(origin_messages)

            result_messages = [Message(**x) for x in self.context.response.answer]
            answer_token_cnt = self.token_count(result_messages)

            # Guard against ZeroDivisionError: with no input messages (or a
            # zero token count) there is nothing to compress anyway.
            if origin_token_cnt <= 0:
                logger.info("Original token count is 0, skipping compression")
                return

            compact_ratio = answer_token_cnt / origin_token_cnt

            compact_ratio_threshold: float = self.context.get("compact_ratio_threshold", 0.75)
            if compact_ratio > compact_ratio_threshold:
                logger.info(f"Compact ratio {compact_ratio:.2f} > {compact_ratio_threshold:.2f}, compress answer")
                # NOTE(review): compression re-reads self.context.messages, which
                # still holds the pre-compaction messages — confirm the compacted
                # answer is propagated into the context before compressing.
                await context_compress_op.async_call(context=self.context)
        else:
            raise ValueError(f"Unknown context management mode: {context_manage_mode}")

    async def async_default_execute(self, e: Exception = None, **_kwargs):
        """Handle execution errors by returning original messages.

        This method is called when an exception occurs during async_execute. It preserves
        the original messages and marks the operation as unsuccessful.

        Args:
            e: The exception that occurred during execution, if any.
            **_kwargs: Additional keyword arguments (unused but required by interface).
        """
        self.context.response.answer = self.context.messages
        self.context.response.success = False
        self.context.response.metadata["error"] = str(e)
+ - AUTO: Represents the automatic context management strategy. + """ + + COMPACT = "compact" + COMPRESS = "compress" + AUTO = "auto" diff --git a/reme_ai/utils/op_utils.py b/reme_ai/utils/op_utils.py index 8345d008..a4d86645 100644 --- a/reme_ai/utils/op_utils.py +++ b/reme_ai/utils/op_utils.py @@ -130,3 +130,24 @@ def parse_update_insight_response(response_text: str, language: str = "en") -> s logger.warning("No insight content found in response") return "" + + +def extract_xml_tag_content(text: str, tag_name: str) -> str | None: + """Extract content from XML tag in text. + + Args: + text: The text containing XML tags. + tag_name: The name of the XML tag to extract (e.g., 'state_snapshot'). + + Returns: + str: The content inside the XML tag, or None if not found. + """ + # Use re.DOTALL to make . match newline characters + pattern = rf"<{tag_name}>(.*?)" + match = re.search(pattern, text, re.DOTALL) + + if match: + content = match.group(1).strip() + return content + + return None diff --git a/test/test_context_compress_op.py b/test/test_context_compress_op.py index 27203f71..520ed033 100644 --- a/test/test_context_compress_op.py +++ b/test/test_context_compress_op.py @@ -222,12 +222,13 @@ async def main(): logger.info("Test 1: Messages below threshold (should skip compression)") logger.info("=" * 60) - compress_op1 = ContextCompressOp( - all_token_threshold=50000, # High threshold, won't trigger - keep_recent=3, - ) + compress_op1 = ContextCompressOp() - await compress_op1.async_call(messages=messages) + await compress_op1.async_call( + messages=messages, + max_total_tokens=50000, # High threshold, won't trigger + keep_recent_count=3, + ) result_messages1 = compress_op1.context.response.answer logger.info(f"✓ Result: {len(result_messages1)} messages (unchanged)") @@ -237,14 +238,15 @@ async def main(): logger.info("Test 2: Messages above threshold (should compress)") logger.info("=" * 60) - compress_op2 = ContextCompressOp( - all_token_threshold=2000, # Low 
threshold, will trigger - keep_recent=3, # Keep last 3 messages + compress_op2 = ContextCompressOp() + + await compress_op2.async_call( + messages=messages, + max_total_tokens=2000, # Low threshold, will trigger + keep_recent_count=3, # Keep last 3 messages compress_system_message=False, # Don't compress system messages ) - await compress_op2.async_call(messages=messages) - result_messages2 = compress_op2.context.response.answer logger.info(f"✓ Result: {len(result_messages2)} messages (compressed)") @@ -260,16 +262,17 @@ async def main(): logger.info("Test 3: Messages above micro threshold (should compress)") logger.info("=!" * 30) - compress_op2 = ContextCompressOp( - all_token_threshold=2000, # Low threshold, will trigger - keep_recent=2, # Keep last 3 messages + compress_op2 = ContextCompressOp() + + await compress_op2.async_call( + messages=messages, + max_total_tokens=2000, # Low threshold, will trigger + keep_recent_count=2, # Keep last 2 messages compress_system_message=False, # Don't compress system messages - micro_summary_token_threshold=1500, + group_token_threshold=1500, language="zh", ) - await compress_op2.async_call(messages=messages) - result_messages2 = compress_op2.context.response.answer logger.info(f"✓ Result: {len(result_messages2)} messages (compressed)") diff --git a/test_op/test_context_compact_op.py b/test_op/test_context_compact_op.py index 5cdf8b96..61a69e11 100644 --- a/test_op/test_context_compact_op.py +++ b/test_op/test_context_compact_op.py @@ -59,16 +59,17 @@ async def async_main(): ] # Create op with lower thresholds for testing - op = ContextCompactOp( - all_token_threshold=1000, # Low threshold to trigger compaction - tool_token_threshold=100, # Low threshold to compact tool messages - tool_left_char_len=50, # Keep 50 chars in preview - keep_recent=1, # Keep 1 recent tool message - storage_path="./test_compact_storage", - ) + op = ContextCompactOp() # Execute the compaction - await op.async_call(messages=[m.model_dump() for m in 
messages]) + await op.async_call( + messages=[m.model_dump() for m in messages], + max_total_tokens=1000, # Low threshold to trigger compaction + max_tool_message_tokens=100, # Low threshold to compact tool messages + preview_char_length=50, # Keep 50 chars in preview + keep_recent_count=1, # Keep 1 recent tool message + storage_path="./test_compact_storage", + ) # Print results result = op.context.response.answer