from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff, TaskResult
from autogen_agentchat.messages import (
+    ChatMessage,
    HandoffMessage,
    MultiModalMessage,
    TextMessage,

from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
-from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
+from openai.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
from openai.types.completion_usage import CompletionUsage
from utils import FileLogHandler

class _MockChatCompletion:
    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        self._saved_chat_completions = chat_completions
-        self._curr_index = 0
+        self.curr_index = 0

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        await asyncio.sleep(0.1)
-        completion = self._saved_chat_completions[self._curr_index]
-        self._curr_index += 1
+        completion = self._saved_chat_completions[self.curr_index]
+        self.curr_index += 1
        return completion
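
The tests below wire this mock into the OpenAI client by monkeypatching `AsyncCompletions.create`; making `curr_index` public is what lets them rewind the mock between a `run()` pass and a `run_stream()` pass over the same task. A minimal sketch of that wiring, assuming pytest-asyncio as used elsewhere in this module (the test name and canned reply are illustrative only, not part of this change):

```python
import pytest
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage


@pytest.mark.asyncio
async def test_mock_wiring(monkeypatch: pytest.MonkeyPatch) -> None:  # hypothetical name
    canned = ChatCompletion(
        id="id1",
        choices=[
            Choice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(content="Hello", role="assistant"),
            )
        ],
        created=0,
        model="gpt-4o-2024-05-13",
        object="chat.completion",
        usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
    )
    mock = _MockChatCompletion([canned])
    # Route every OpenAI chat-completion call to the mock, which replays the
    # saved completions in order.
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    first = await mock.mock_create()
    assert first.choices[0].message.content == "Hello"
    assert mock.curr_index == 1  # public attribute; reset to 0 to replay for streaming
```
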
@@ -90,7 +94,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
        ChatCompletion(
            id="id2",
            choices=[
-                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="Hello", role="assistant"),
+                )
            ],
            created=0,
            model=model,
@@ -101,7 +109,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
            id="id2",
            choices=[
                Choice(
-                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
                )
            ],
            created=0,
@@ -115,7 +125,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
    )
    result = await agent.run(task="task")
    assert len(result.messages) == 4
@@ -133,7 +147,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    assert result.messages[3].models_usage.prompt_tokens == 10

    # Test streaming.
-    mock._curr_index = 0  # pyright: ignore
+    mock.curr_index = 0  # pyright: ignore
    index = 0
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
@@ -147,7 +161,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
@@ -192,7 +210,11 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+        tools=[
+            _pass_function,
+            _fail_function,
+            FunctionTool(_echo_function, description="Echo"),
+        ],
        handoffs=[handoff],
    )
    assert HandoffMessage in tool_use_agent.produced_message_types
@@ -212,7 +234,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
    assert result.messages[3].models_usage is None

    # Test streaming.
-    mock._curr_index = 0  # pyright: ignore
+    mock.curr_index = 0  # pyright: ignore
    index = 0
    async for message in tool_use_agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
@@ -229,7 +251,11 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
        ChatCompletion(
            id="id2",
            choices=[
-                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="Hello", role="assistant"),
+                )
            ],
            created=0,
            model=model,
@@ -239,7 +265,10 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
-    agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
+    agent = AssistantAgent(
+        name="assistant",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+    )
    # Generate a random base64 image.
    img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
@@ -250,14 +279,24 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_invalid_model_capabilities() -> None:
    model = "random-model"
    model_client = OpenAIChatCompletionClient(
-        model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False}
+        model=model,
+        api_key="",
+        model_capabilities={
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+        },
    )

    with pytest.raises(ValueError):
        agent = AssistantAgent(
            name="assistant",
            model_client=model_client,
-            tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+            tools=[
+                _pass_function,
+                _fail_function,
+                FunctionTool(_echo_function, description="Echo"),
+            ],
        )

    with pytest.raises(ValueError):
@@ -268,3 +307,62 @@ async def test_invalid_model_capabilities() -> None:
        # Generate a random base64 image.
        img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
        await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
+
+
+@pytest.mark.asyncio
+async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    agent = AssistantAgent(
+        "test_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+    )
+
+    # Create a list of chat messages
+    messages: List[ChatMessage] = [
+        TextMessage(content="Message 1", source="user"),
+        TextMessage(content="Message 2", source="user"),
+    ]
+
+    # Test run method with list of messages
+    result = await agent.run(task=messages)
+    assert len(result.messages) == 3  # 2 input messages + 1 response message
+    assert isinstance(result.messages[0], TextMessage)
+    assert result.messages[0].content == "Message 1"
+    assert result.messages[0].source == "user"
+    assert isinstance(result.messages[1], TextMessage)
+    assert result.messages[1].content == "Message 2"
+    assert result.messages[1].source == "user"
+    assert isinstance(result.messages[2], TextMessage)
+    assert result.messages[2].content == "Response to message 1"
+    assert result.messages[2].source == "test_agent"
+    assert result.messages[2].models_usage is not None
+    assert result.messages[2].models_usage.completion_tokens == 5
+    assert result.messages[2].models_usage.prompt_tokens == 10
+
+    # Test run_stream method with list of messages
+    mock.curr_index = 0  # Reset mock index using public attribute
+    index = 0
+    async for message in agent.run_stream(task=messages):
+        if isinstance(message, TaskResult):
+            assert message == result
+        else:
+            assert message == result.messages[index]
+            index += 1