Skip to content

Commit 376f337

Browse files
committed
feat: add list input support for chat messages with unit tests
1 parent 9712f2b commit 376f337

File tree

3 files changed

+131
-19
lines changed

3 files changed

+131
-19
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py

+18-3
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,14 @@
44
from autogen_core import CancellationToken
55

66
from ..base import ChatAgent, Response, TaskResult
7-
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
7+
from ..messages import (
8+
AgentMessage,
9+
ChatMessage,
10+
HandoffMessage,
11+
MultiModalMessage,
12+
StopMessage,
13+
TextMessage,
14+
)
815
from ..state import BaseState
916

1017

@@ -45,8 +52,9 @@ async def on_messages_stream(
4552
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4653
) -> AsyncGenerator[AgentMessage | Response, None]:
4754
"""Handles incoming messages and returns a stream of messages and
48-
and the final item is the response. The base implementation in :class:`BaseChatAgent`
49-
simply calls :meth:`on_messages` and yields the messages in the response."""
55+
the final item is the response. The base implementation in
56+
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
57+
the messages in the response."""
5058
response = await self.on_messages(messages, cancellation_token)
5159
for inner_message in response.inner_messages or []:
5260
yield inner_message
@@ -69,6 +77,13 @@ async def run(
6977
text_msg = TextMessage(content=task, source="user")
7078
input_messages.append(text_msg)
7179
output_messages.append(text_msg)
80+
elif isinstance(task, list):
81+
for msg in task:
82+
if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
83+
input_messages.append(msg)
84+
output_messages.append(msg)
85+
else:
86+
raise ValueError(f"Invalid message type in list: {type(msg)}")
7287
elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
7388
input_messages.append(task)
7489
output_messages.append(task)

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py

-1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
ToolCallMessage,
2323
ToolCallResultMessage,
2424
)
25-
2625
from ....state import MagenticOneOrchestratorState
2726
from .._base_group_chat_manager import BaseGroupChatManager
2827
from .._events import (

python/packages/autogen-agentchat/tests/test_assistant_agent.py

+113-15
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from autogen_agentchat.agents import AssistantAgent
99
from autogen_agentchat.base import Handoff, TaskResult
1010
from autogen_agentchat.messages import (
11+
ChatMessage,
1112
HandoffMessage,
1213
MultiModalMessage,
1314
TextMessage,
@@ -21,7 +22,10 @@
2122
from openai.types.chat.chat_completion import ChatCompletion, Choice
2223
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
2324
from openai.types.chat.chat_completion_message import ChatCompletionMessage
24-
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
25+
from openai.types.chat.chat_completion_message_tool_call import (
26+
ChatCompletionMessageToolCall,
27+
Function,
28+
)
2529
from openai.types.completion_usage import CompletionUsage
2630
from utils import FileLogHandler
2731

@@ -33,14 +37,14 @@
3337
class _MockChatCompletion:
3438
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
3539
self._saved_chat_completions = chat_completions
36-
self._curr_index = 0
40+
self.curr_index = 0
3741

3842
async def mock_create(
3943
self, *args: Any, **kwargs: Any
4044
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
4145
await asyncio.sleep(0.1)
42-
completion = self._saved_chat_completions[self._curr_index]
43-
self._curr_index += 1
46+
completion = self._saved_chat_completions[self.curr_index]
47+
self.curr_index += 1
4448
return completion
4549

4650

@@ -90,7 +94,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
9094
ChatCompletion(
9195
id="id2",
9296
choices=[
93-
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
97+
Choice(
98+
finish_reason="stop",
99+
index=0,
100+
message=ChatCompletionMessage(content="Hello", role="assistant"),
101+
)
94102
],
95103
created=0,
96104
model=model,
@@ -101,7 +109,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
101109
id="id2",
102110
choices=[
103111
Choice(
104-
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
112+
finish_reason="stop",
113+
index=0,
114+
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
105115
)
106116
],
107117
created=0,
@@ -115,7 +125,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
115125
agent = AssistantAgent(
116126
"tool_use_agent",
117127
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
118-
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
128+
tools=[
129+
_pass_function,
130+
_fail_function,
131+
FunctionTool(_echo_function, description="Echo"),
132+
],
119133
)
120134
result = await agent.run(task="task")
121135
assert len(result.messages) == 4
@@ -133,7 +147,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
133147
assert result.messages[3].models_usage.prompt_tokens == 10
134148

135149
# Test streaming.
136-
mock._curr_index = 0 # pyright: ignore
150+
mock.curr_index = 0 # pyright: ignore
137151
index = 0
138152
async for message in agent.run_stream(task="task"):
139153
if isinstance(message, TaskResult):
@@ -147,7 +161,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
147161
agent2 = AssistantAgent(
148162
"tool_use_agent",
149163
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
150-
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
164+
tools=[
165+
_pass_function,
166+
_fail_function,
167+
FunctionTool(_echo_function, description="Echo"),
168+
],
151169
)
152170
await agent2.load_state(state)
153171
state2 = await agent2.save_state()
@@ -192,7 +210,11 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
192210
tool_use_agent = AssistantAgent(
193211
"tool_use_agent",
194212
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
195-
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
213+
tools=[
214+
_pass_function,
215+
_fail_function,
216+
FunctionTool(_echo_function, description="Echo"),
217+
],
196218
handoffs=[handoff],
197219
)
198220
assert HandoffMessage in tool_use_agent.produced_message_types
@@ -212,7 +234,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
212234
assert result.messages[3].models_usage is None
213235

214236
# Test streaming.
215-
mock._curr_index = 0 # pyright: ignore
237+
mock.curr_index = 0 # pyright: ignore
216238
index = 0
217239
async for message in tool_use_agent.run_stream(task="task"):
218240
if isinstance(message, TaskResult):
@@ -229,7 +251,11 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
229251
ChatCompletion(
230252
id="id2",
231253
choices=[
232-
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
254+
Choice(
255+
finish_reason="stop",
256+
index=0,
257+
message=ChatCompletionMessage(content="Hello", role="assistant"),
258+
)
233259
],
234260
created=0,
235261
model=model,
@@ -239,7 +265,10 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
239265
]
240266
mock = _MockChatCompletion(chat_completions)
241267
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
242-
agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
268+
agent = AssistantAgent(
269+
name="assistant",
270+
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
271+
)
243272
# Generate a random base64 image.
244273
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
245274
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
@@ -250,14 +279,24 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
250279
async def test_invalid_model_capabilities() -> None:
251280
model = "random-model"
252281
model_client = OpenAIChatCompletionClient(
253-
model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False}
282+
model=model,
283+
api_key="",
284+
model_capabilities={
285+
"vision": False,
286+
"function_calling": False,
287+
"json_output": False,
288+
},
254289
)
255290

256291
with pytest.raises(ValueError):
257292
agent = AssistantAgent(
258293
name="assistant",
259294
model_client=model_client,
260-
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
295+
tools=[
296+
_pass_function,
297+
_fail_function,
298+
FunctionTool(_echo_function, description="Echo"),
299+
],
261300
)
262301

263302
with pytest.raises(ValueError):
@@ -268,3 +307,62 @@ async def test_invalid_model_capabilities() -> None:
268307
# Generate a random base64 image.
269308
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
270309
await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
310+
311+
312+
@pytest.mark.asyncio
313+
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
314+
model = "gpt-4o-2024-05-13"
315+
chat_completions = [
316+
ChatCompletion(
317+
id="id1",
318+
choices=[
319+
Choice(
320+
finish_reason="stop",
321+
index=0,
322+
message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
323+
)
324+
],
325+
created=0,
326+
model=model,
327+
object="chat.completion",
328+
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
329+
),
330+
]
331+
mock = _MockChatCompletion(chat_completions)
332+
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
333+
agent = AssistantAgent(
334+
"test_agent",
335+
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
336+
)
337+
338+
# Create a list of chat messages
339+
messages: List[ChatMessage] = [
340+
TextMessage(content="Message 1", source="user"),
341+
TextMessage(content="Message 2", source="user"),
342+
]
343+
344+
# Test run method with list of messages
345+
result = await agent.run(task=messages)
346+
assert len(result.messages) == 3 # 2 input messages + 1 response message
347+
assert isinstance(result.messages[0], TextMessage)
348+
assert result.messages[0].content == "Message 1"
349+
assert result.messages[0].source == "user"
350+
assert isinstance(result.messages[1], TextMessage)
351+
assert result.messages[1].content == "Message 2"
352+
assert result.messages[1].source == "user"
353+
assert isinstance(result.messages[2], TextMessage)
354+
assert result.messages[2].content == "Response to message 1"
355+
assert result.messages[2].source == "test_agent"
356+
assert result.messages[2].models_usage is not None
357+
assert result.messages[2].models_usage.completion_tokens == 5
358+
assert result.messages[2].models_usage.prompt_tokens == 10
359+
360+
# Test run_stream method with list of messages
361+
mock.curr_index = 0 # Reset mock index using public attribute
362+
index = 0
363+
async for message in agent.run_stream(task=messages):
364+
if isinstance(message, TaskResult):
365+
assert message == result
366+
else:
367+
assert message == result.messages[index]
368+
index += 1

0 commit comments

Comments
 (0)