@@ -3,6 +3,7 @@
 from typing import Dict, Optional, List
 from autogen import ConversableAgent
 from autogen import token_count_utils
+import tiktoken


 class TransformChatHistory:
@@ -53,56 +54,70 @@ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
             messages: List of messages to process.

         Returns:
-            List of messages with the first system message and the last max_messages messages.
+            List of messages with the first system message and the last max_messages messages,
+            ensuring each message does not exceed max_tokens_per_message.
         """
+        temp_messages = messages.copy()
         processed_messages = []
-        messages = messages.copy()
-        rest_messages = messages
+        system_message = None
+        processed_messages_tokens = 0

-        # check if the first message is a system message and append it to the processed messages
-        if len(messages) > 0:
-            if messages[0]["role"] == "system":
-                msg = messages[0]
-                processed_messages.append(msg)
-                rest_messages = messages[1:]
+        if messages and messages[0]["role"] == "system":
+            system_message = messages[0].copy()
+            temp_messages.pop(0)

-        processed_messages_tokens = 0
-        for msg in messages:
-            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
+        total_tokens = sum(
+            token_count_utils.count_token(msg["content"]) for msg in temp_messages
+        )  # Calculate tokens for all messages

-        # iterate through rest of the messages and append them to the processed messages
-        for msg in rest_messages[-self.max_messages :]:
+        # Truncate each message's content to at most max_tokens_per_message tokens
+
+        # Process the most recent messages first
+        for msg in reversed(temp_messages[-self.max_messages :]):
+            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
             msg_tokens = token_count_utils.count_token(msg["content"])
             if processed_messages_tokens + msg_tokens > self.max_tokens:
                 break
-            processed_messages.append(msg)
+            # prepend the message to preserve chronological order
+            processed_messages = [msg] + processed_messages
             processed_messages_tokens += msg_tokens
-
-        total_tokens = 0
-        for msg in messages:
-            total_tokens += token_count_utils.count_token(msg["content"])
-
+        if system_message:
+            processed_messages.insert(0, system_message)
+        # Report the number of truncated messages and tokens
         num_truncated = len(messages) - len(processed_messages)
+
         if num_truncated > 0 or total_tokens > processed_messages_tokens:
-            print(colored(f"Truncated {len(messages) - len(processed_messages)} messages.", "yellow"))
-            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
+            print(
+                colored(
+                    f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
+                    "yellow",
+                )
+            )
+            print(
+                colored(
+                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}.",
+                    "yellow",
+                )
+            )
         return processed_messages


-def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
-    """
-    Truncate a string so that number of tokens in less than max_tokens.
+def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
+    """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.

     Args:
-        content: String to process.
-        max_tokens: Maximum number of tokens to keep.
+        text: The string to truncate.
+        max_tokens: The maximum number of tokens to keep.
+        model: The target OpenAI model for tokenization alignment.

     Returns:
-        Truncated string.
+        The truncated string.
     """
-    truncated_string = ""
-    for char in text:
-        truncated_string += char
-        if token_count_utils.count_token(truncated_string) == max_tokens:
-            break
-    return truncated_string
+
+    encoding = tiktoken.encoding_for_model(model)  # Get the appropriate tokenizer
+
+    encoded_tokens = encoding.encode(text)
+    truncated_tokens = encoded_tokens[:max_tokens]
+    truncated_text = encoding.decode(truncated_tokens)  # Decode back to text
+
+    return truncated_text
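A minimal standalone sketch of the tiktoken-based truncation introduced above, mirroring the patched helper rather than importing it; the model name, sample text, and token limit are illustrative assumptions, and the tiktoken package must be installed:

import tiktoken

def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
    # Encode with the model's tokenizer, keep the first max_tokens tokens, decode back to text.
    encoding = tiktoken.encoding_for_model(model)
    return encoding.decode(encoding.encode(text)[:max_tokens])

sample = "This is a fairly long chat message that should be cut down to only a few tokens."
print(truncate_str_to_tokens(sample, max_tokens=5))  # prints roughly the first five tokens' worth of text

Slicing the encoded token list keeps the result at or under the limit in a single pass, unlike the previous character-by-character loop that re-counted tokens on every iteration.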