import sys

from termcolor import colored
from typing import Dict, List, Optional

from autogen import ConversableAgent
from autogen import token_count_utils

class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes it on every invocation.

    This capability class enables various strategies to transform the chat history, such as:
    - Truncate messages: Truncate each message to a maximum number of tokens.
    - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
    - Limit number of tokens: Truncate the chat history to the most recent messages that fit
      within a maximum number of tokens.
    Note that the system message, because of its special significance, is always kept as is.

    The three strategies can be combined. When more than one of these parameters is specified,
    they are applied in the following order:
    1. First, each message is truncated to the maximum number of tokens per message.
    2. Second, the chat history is limited to the maximum number of messages to keep.
    3. Third, the chat history is limited to the maximum total number of tokens.

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
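
    Example:
        A minimal usage sketch (illustrative only; llm_config and the parameter
        values below are assumptions, not part of this module):

            agent = ConversableAgent("assistant", llm_config=llm_config)
            capability = TransformChatHistory(
                max_tokens_per_message=500,  # clip long individual messages
                max_messages=20,             # keep only the 20 most recent messages
                max_tokens=2000,             # cap the total size of the history
            )
            capability.add_to_agent(agent)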
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: ConversableAgent):
        """
        Adds TransformChatHistory capability to the given agent.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Args:
            messages: List of messages to process.

        Returns:
            List of messages with the first system message kept as is, followed by
            the most recent messages that satisfy the configured limits.
        """
        processed_messages = []
        messages = messages.copy()
        rest_messages = messages

        # if the first message is a system message, keep it as is and only
        # transform the remaining messages
        if len(messages) > 0 and messages[0]["role"] == "system":
            processed_messages.append(messages[0])
            rest_messages = messages[1:]

        # truncate each non-system message to max_tokens_per_message, copying the
        # message dict so the caller's chat history is not mutated in place
        rest_messages = [
            {**msg, "content": truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)}
            if isinstance(msg.get("content"), str)
            else msg
            for msg in rest_messages
        ]

        # keep at most max_messages recent messages, walking backwards from the
        # most recent one so that, when the token budget runs out, it is the
        # oldest messages that get dropped
        processed_messages_tokens = 0
        insert_idx = len(processed_messages)  # insert after the system message, if any
        for msg in reversed(rest_messages[-self.max_messages :]):
            msg_tokens = token_count_utils.count_token(msg["content"])
            if processed_messages_tokens + msg_tokens > self.max_tokens:
                break
            processed_messages.insert(insert_idx, msg)
            processed_messages_tokens += msg_tokens

        total_tokens = 0
        for msg in messages:
            total_tokens += token_count_utils.count_token(msg["content"])

        num_truncated = len(messages) - len(processed_messages)
        if num_truncated > 0 or total_tokens > processed_messages_tokens:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
        return processed_messages


def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that its number of tokens is at most max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        Truncated string.
    """
    # grow the string one character at a time, stopping just before the token
    # count would exceed max_tokens (note: this costs one count_token call per
    # character, which is slow for long strings)
    truncated_string = ""
    for char in text:
        if token_count_utils.count_token(truncated_string + char) > max_tokens:
            break
        truncated_string += char
    return truncated_string
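

if __name__ == "__main__":
    # Minimal smoke demo of the truncation helper, illustrative only: the exact
    # cut-off point depends on the tokenizer behind token_count_utils.count_token
    # and its default model, so the printed prefix length may vary.
    sample = "The quick brown fox jumps over the lazy dog. " * 10
    print(truncate_str_to_tokens(sample, max_tokens=10))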