|
1 |
| -import copy |
2 | 1 | import os
|
3 | 2 | import sys
|
4 | 3 | import tempfile
|
|
7 | 6 | import pytest
|
8 | 7 |
|
9 | 8 | import autogen
|
10 |
| -from autogen import token_count_utils |
11 | 9 | from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
|
12 | 10 | from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter
|
13 | 11 |
|
|
18 | 16 | from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
19 | 17 |
|
20 | 18 |
|
21 |
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
    """Return the token count of ``content``.

    A plain string is counted directly via ``token_count_utils.count_token``.
    A list of multimodal items is counted by summing the token counts of each
    item's ``"text"`` field (missing fields count as empty). Any other type
    contributes zero tokens.
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)
    if isinstance(content, list):
        return sum(_count_tokens(entry.get("text", "")) for entry in content)
    return 0
29 |
| - |
30 |
| - |
31 |
def test_limit_token_transform():
    """Test the MessageTokenLimiter transform.

    Exercises three configurations: a per-message token cap, a total token
    cap, and both caps applied together.
    """
    messages = [
        {"role": "user", "content": "short string"},
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "very very very very very very very very long string"}],
        },
    ]

    # Per-message cap: no single transformed message may exceed the limit.
    per_message_cap = 5
    transform = MessageTokenLimiter(max_tokens_per_message=per_message_cap)
    result = transform.apply_transform(copy.deepcopy(messages))
    assert all(_count_tokens(msg["content"]) <= per_message_cap for msg in result)

    # Total cap: the sum over all transformed messages stays within bounds,
    # and the transform never adds messages.
    total_cap = 10
    transform = MessageTokenLimiter(max_tokens=total_cap)
    result = transform.apply_transform(copy.deepcopy(messages))
    assert sum(_count_tokens(msg["content"]) for msg in result) <= total_cap
    assert len(result) <= len(messages)

    # Both caps together: every message and the running total must comply.
    transform = MessageTokenLimiter(max_tokens=total_cap, max_tokens_per_message=per_message_cap)
    result = transform.apply_transform(copy.deepcopy(messages))

    running_total = 0
    for msg in result:
        per_msg_count = _count_tokens(msg["content"])
        running_total += per_msg_count
        assert per_msg_count <= per_message_cap

    assert running_total <= total_cap
    assert len(result) <= len(messages)
77 |
| - |
78 |
| - |
79 |
def test_limit_token_transform_without_content():
    """Test the MessageTokenLimiter with messages that carry no usable content.

    One message has only a ``function_call`` entry and the other has
    ``content`` set to ``None``; neither should be dropped by the transform.
    """
    messages = [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]

    # Apply both the total and per-message caps together.
    transform = MessageTokenLimiter(max_tokens=10, max_tokens_per_message=5)
    result = transform.apply_transform(copy.deepcopy(messages))

    # Content-less messages survive the transform untouched in count.
    assert len(result) == len(messages)
90 |
| - |
91 |
| - |
92 |
def test_limit_token_transform_total_token_count():
    """Tests that the MessageTokenLimiter truncates rather than drops messages."""
    # NOTE(review): this message stores its text under "role" and has no
    # "content" key — presumably deliberate, to exercise the no-content path
    # with a tiny max_tokens; verify the intent against the transform's docs.
    messages = [{"role": "very very very very very"}]

    transform = MessageTokenLimiter(max_tokens=1)
    result = transform.apply_transform(copy.deepcopy(messages))

    # Even with max_tokens=1 the single message must still be present.
    assert len(result) == 1
100 |
| - |
101 |
| - |
102 |
def test_max_message_history_length_transform():
    """Test that MessageHistoryLimiter keeps only the most recent messages."""
    messages = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": [{"type": "text", "text": "there"}]},
        {"role": "user", "content": "how"},
        {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
    ]

    history_cap = 2
    limiter = MessageHistoryLimiter(max_messages=history_cap)
    result = limiter.apply_transform(copy.deepcopy(messages))

    # Exactly the last `history_cap` messages remain, in original order.
    assert len(result) == history_cap
    assert result == messages[history_cap:]
119 |
| - |
120 |
| - |
121 | 19 | @pytest.mark.skipif(skip_openai, reason="Requested to skip openai test.")
|
122 | 20 | def test_transform_messages_capability():
|
123 | 21 | """Test the TransformMessages capability to handle long contexts.
|
@@ -172,6 +70,4 @@ def test_transform_messages_capability():
|
172 | 70 |
|
173 | 71 |
|
if __name__ == "__main__":
    # Run the suite directly without pytest collection.
    test_limit_token_transform()
    test_max_message_history_length_transform()
    test_transform_messages_capability()
|
0 commit comments