From 6a222c9ac5b6f5bc34c681b3913e436da7d0fb42 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 31 Oct 2024 07:24:37 -0700 Subject: [PATCH 1/6] Initial stream APIs --- .../agents/_base_chat_agent.py | 12 +++++++++++- .../src/autogen_agentchat/base/_chat_agent.py | 18 +++++++----------- .../src/autogen_agentchat/base/_task.py | 12 +++++++++++- .../src/autogen_agentchat/base/_team.py | 14 ++------------ 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index ac74077e27c8..bbebb61edfcb 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Sequence +from typing import AsyncGenerator, List, Sequence from autogen_core.base import CancellationToken @@ -39,6 +39,16 @@ def produced_message_types(self) -> List[type[ChatMessage]]: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: """Handles incoming messages and returns a response.""" ... + + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[InnerMessage | ChatMessage | Response, None]: + """Handles incoming messages and returns a stream of messages and + and the final item is the response.""" + response = await self.on_messages(messages, cancellation_token) + for inner_message in response.inner_messages or []: + yield inner_message + yield response async def run( self, diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index d60dba349cbb..b6c76fa1f9be 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -1,11 +1,10 @@ from dataclasses import dataclass -from typing import List, Protocol, Sequence, runtime_checkable +from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable from autogen_core.base import CancellationToken from ..messages import ChatMessage, InnerMessage -from ._task import TaskResult, TaskRunner -from ._termination import TerminationCondition +from ._task import TaskRunner @dataclass(kw_only=True) @@ -45,12 +44,9 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: """Handles incoming messages and returns a response.""" ... - async def run( - self, - task: str, - *, - cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, - ) -> TaskResult: - """Run the agent with the given task and return the result.""" + def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[InnerMessage | ChatMessage | Response, None]: + """Handles incoming messages and returns a stream of messages and + and the final item is the response.""" ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index 326cceecb1fd..1b61fb6ffab0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Protocol, Sequence +from typing import AsyncIterator, Protocol, Sequence from autogen_core.base import CancellationToken @@ -27,3 +27,13 @@ async def run( ) -> TaskResult: """Run the task.""" ... + + def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + ) -> AsyncIterator[InnerMessage | ChatMessage | TaskResult]: + """Run the task and produces a stream of messages and the final result + as the last item in the stream.""" + ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py index b0a1dc3d2a38..4028bb279887 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py @@ -1,18 +1,8 @@ from typing import Protocol -from autogen_core.base import CancellationToken -from ._task import TaskResult, TaskRunner -from ._termination import TerminationCondition +from ._task import TaskRunner class Team(TaskRunner, Protocol): - async def run( - self, - task: str, - *, - cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, - ) -> TaskResult: - """Run the team on a given task until the termination condition is met.""" - ... + pass From 305ae5eb41ccd4bf9e561ab8902d4aea7deefaa6 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 31 Oct 2024 08:43:10 -0700 Subject: [PATCH 2/6] WIP --- .../agents/_base_chat_agent.py | 22 ++++++++++++++++++- .../src/autogen_agentchat/base/_task.py | 4 ++-- .../teams/_group_chat/_base_group_chat.py | 10 ++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index bbebb61edfcb..03d9c712d03d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -44,7 +44,8 @@ async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[InnerMessage | ChatMessage | Response, None]: """Handles incoming messages and returns a stream of messages and - and the final item is the response.""" + and the final item is the response. The base implementation in :class:`BaseChatAgent` + simply calls :meth:`on_messages` and yields the messages in the response.""" response = await self.on_messages(messages, cancellation_token) for inner_message in response.inner_messages or []: yield inner_message @@ -67,3 +68,22 @@ async def run( messages += response.inner_messages messages.append(response.chat_message) return TaskResult(messages=messages) + + async def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: + """Run the agent with the given task and return a stream of messages + and the final task result as the last item in the stream.""" + if cancellation_token is None: + cancellation_token = CancellationToken() + first_message = TextMessage(content=task, source="user") + messages: List[InnerMessage | ChatMessage] = [first_message] + async for message in self.on_messages_stream([first_message], cancellation_token): + if isinstance(message, Response): + yield TaskResult(messages=messages) + else: + messages.append(message) + yield message \ No newline at end of file diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index 1b61fb6ffab0..375c4aa7f2c5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import AsyncIterator, Protocol, Sequence +from typing import AsyncGenerator, Protocol, Sequence from autogen_core.base import CancellationToken @@ -33,7 +33,7 @@ def run_stream( task: str, *, cancellation_token: CancellationToken | None = None, - ) -> AsyncIterator[InnerMessage | ChatMessage | TaskResult]: + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: """Run the task and produces a stream of messages and the final result as the last item in the stream.""" ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 9f3132a74955..1bdee193a6cd 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -1,6 +1,6 @@ import uuid from abc import ABC, abstractmethod -from typing import Callable, List +from typing import AsyncGenerator, Callable, List from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import ( @@ -169,3 +169,11 @@ async def collect_output_messages( # Return the result. return TaskResult(messages=output_messages) + + async def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: + pass \ No newline at end of file From 4903bf9c9871d45d82c66f5c4ee2cc40df23499d Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 31 Oct 2024 10:48:22 -0700 Subject: [PATCH 3/6] WIP --- .../teams/_group_chat/_base_group_chat.py | 53 +++++++++++++------ .../_group_chat/_chat_agent_container.py | 34 ++++++------ 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 1bdee193a6cd..453ae486b6f2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -1,3 +1,4 @@ +import asyncio import uuid from abc import ABC, abstractmethod from typing import AsyncGenerator, Callable, List @@ -75,9 +76,22 @@ async def run( cancellation_token: CancellationToken | None = None, termination_condition: TerminationCondition | None = None, ) -> TaskResult: - """Run the team and return the result.""" - # Create intervention handler for termination. + """Run the team and return the result. The base implementation uses + :meth:`run_stream` to run the team and then returns the final result.""" + async for message in self.run_stream(task, cancellation_token=cancellation_token): + if isinstance(message, TaskResult): + return message + assert False, "The stream should have returned the final result." + + async def run_stream( + self, + task: str, + *, + cancellation_token: CancellationToken | None = None, + ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: + """Run the team and produces a stream of messages and the final result + as the last item in the stream.""" # Create the runtime. runtime = SingleThreadedAgentRuntime() @@ -117,7 +131,7 @@ async def run( group_topic_type=group_topic_type, participant_topic_types=participant_topic_types, participant_descriptions=participant_descriptions, - termination_condition=termination_condition or self._termination_condition, + termination_condition=self._termination_condition, ), ) # Add subscriptions for the group chat manager. @@ -132,6 +146,7 @@ async def run( ) output_messages: List[InnerMessage | ChatMessage] = [] + output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue() async def collect_output_messages( _runtime: AgentRuntime, @@ -140,6 +155,7 @@ async def collect_output_messages( ctx: MessageContext, ) -> None: output_messages.append(message) + await output_message_queue.put(message) await ClosureAgent.register( runtime, @@ -164,16 +180,21 @@ async def collect_output_messages( ) await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id) - # Wait for the runtime to stop. - await runtime.stop_when_idle() - - # Return the result. - return TaskResult(messages=output_messages) - - async def run_stream( - self, - task: str, - *, - cancellation_token: CancellationToken | None = None, - ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: - pass \ No newline at end of file + # Start a coroutine to stop the runtime and signal the output message queue is complete. + async def stop_runtime() -> None: + await runtime.stop_when_idle() + await output_message_queue.put(None) + shutdown_task = asyncio.create_task(stop_runtime()) + + # Yield the messsages until the queue is empty. + while True: + message = await output_message_queue.get() + if message is None: + break + yield message + + # Wait for the shutdown task to finish. + await shutdown_task + + # Yield the final result. + yield TaskResult(messages=output_messages) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 1423735c2f7c..c05faf399763 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -3,7 +3,7 @@ from autogen_core.base import MessageContext from autogen_core.components import DefaultTopicId, event -from ...base import ChatAgent +from ...base import ChatAgent, Response from ...messages import ChatMessage from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent from ._sequential_routed_agent import SequentialRoutedAgent @@ -37,28 +37,28 @@ async def handle_content_request(self, message: GroupChatRequestPublishEvent, ct """Handle a content request event by passing the messages in the buffer to the delegate agent and publish the response.""" # Pass the messages in the buffer to the delegate agent. - response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) - if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types): - raise ValueError( - f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. " - f"Expected one of: {self._agent.produced_message_types}. " - f"Check the agent's produced_message_types property." - ) + response: Response | None = None + async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): + if not any(isinstance(msg, msg_type) for msg_type in self._agent.produced_message_types): + raise ValueError( + f"The agent {self._agent.name} produced an unexpected message type: {type(msg)}. " + f"Expected one of: {self._agent.produced_message_types}. " + f"Check the agent's produced_message_types property." + ) + if isinstance(msg, Response): + response = msg + else: + # Publish the message to the output topic. + await self.publish_message(msg, topic_id=DefaultTopicId(type=self._output_topic_type)) + if response is None: + raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.") - # Publish inner messages to the output topic. - if response.inner_messages is not None: - for inner_message in response.inner_messages: - await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type)) - - # Publish the response. + # Publish the response to the group chat. self._message_buffer.clear() await self.publish_message( GroupChatPublishEvent(agent_message=response.chat_message, source=self.id), topic_id=DefaultTopicId(type=self._parent_topic_type), ) - # Publish the response to the output topic. - await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type)) - async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: raise ValueError(f"Unhandled message in agent container: {type(message)}") From 5fdfb5ca13fa310a88ef9a97ddc9abbe0ff3bd4b Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 31 Oct 2024 12:22:16 -0700 Subject: [PATCH 4/6] WIP --- .../agents/_base_chat_agent.py | 9 +- .../src/autogen_agentchat/base/_task.py | 4 +- .../src/autogen_agentchat/base/_team.py | 1 - .../teams/_group_chat/_base_group_chat.py | 16 ++- .../_group_chat/_chat_agent_container.py | 6 - .../tests/test_assistant_agent.py | 33 +++++- .../tests/test_group_chat.py | 108 ++++++++++++++++-- 7 files changed, 142 insertions(+), 35 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 03d9c712d03d..99f6f1a7e71a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -3,7 +3,7 @@ from autogen_core.base import CancellationToken -from ..base import ChatAgent, Response, TaskResult, TerminationCondition +from ..base import ChatAgent, Response, TaskResult from ..messages import ChatMessage, InnerMessage, TextMessage @@ -39,7 +39,7 @@ def produced_message_types(self) -> List[type[ChatMessage]]: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: """Handles incoming messages and returns a response.""" ... - + async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[InnerMessage | ChatMessage | Response, None]: @@ -49,6 +49,7 @@ async def on_messages_stream( response = await self.on_messages(messages, cancellation_token) for inner_message in response.inner_messages or []: yield inner_message + yield response.chat_message yield response async def run( @@ -56,7 +57,6 @@ async def run( task: str, *, cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, ) -> TaskResult: """Run the agent with the given task and return the result.""" if cancellation_token is None: @@ -80,10 +80,11 @@ async def run_stream( if cancellation_token is None: cancellation_token = CancellationToken() first_message = TextMessage(content=task, source="user") + yield first_message messages: List[InnerMessage | ChatMessage] = [first_message] async for message in self.on_messages_stream([first_message], cancellation_token): if isinstance(message, Response): yield TaskResult(messages=messages) else: messages.append(message) - yield message \ No newline at end of file + yield message diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index 375c4aa7f2c5..2e68c2b8118b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -4,7 +4,6 @@ from autogen_core.base import CancellationToken from ..messages import ChatMessage, InnerMessage -from ._termination import TerminationCondition @dataclass @@ -23,11 +22,10 @@ async def run( task: str, *, cancellation_token: CancellationToken | None = None, - termination_condition: TerminationCondition | None = None, ) -> TaskResult: """Run the task.""" ... - + def run_stream( self, task: str, diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py index 4028bb279887..e112a3b512ed 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_team.py @@ -1,6 +1,5 @@ from typing import Protocol - from ._task import TaskRunner diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 453ae486b6f2..78ec5159e369 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -76,19 +76,21 @@ async def run( cancellation_token: CancellationToken | None = None, termination_condition: TerminationCondition | None = None, ) -> TaskResult: - """Run the team and return the result. The base implementation uses + """Run the team and return the result. The base implementation uses :meth:`run_stream` to run the team and then returns the final result.""" - async for message in self.run_stream(task, cancellation_token=cancellation_token): + async for message in self.run_stream( + task, cancellation_token=cancellation_token, termination_condition=termination_condition + ): if isinstance(message, TaskResult): return message - assert False, "The stream should have returned the final result." - + raise AssertionError("The stream should have returned the final result.") async def run_stream( self, task: str, *, cancellation_token: CancellationToken | None = None, + termination_condition: TerminationCondition | None = None, ) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result as the last item in the stream.""" @@ -131,7 +133,7 @@ async def run_stream( group_topic_type=group_topic_type, participant_topic_types=participant_topic_types, participant_descriptions=participant_descriptions, - termination_condition=self._termination_condition, + termination_condition=termination_condition or self._termination_condition, ), ) # Add subscriptions for the group chat manager. @@ -174,6 +176,7 @@ async def collect_output_messages( group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) first_chat_message = TextMessage(content=task, source="user") output_messages.append(first_chat_message) + await output_message_queue.put(first_chat_message) await runtime.publish_message( GroupChatPublishEvent(agent_message=first_chat_message), topic_id=team_topic_id, @@ -184,6 +187,7 @@ async def collect_output_messages( async def stop_runtime() -> None: await runtime.stop_when_idle() await output_message_queue.put(None) + shutdown_task = asyncio.create_task(stop_runtime()) # Yield the messsages until the queue is empty. @@ -192,7 +196,7 @@ async def stop_runtime() -> None: if message is None: break yield message - + # Wait for the shutdown task to finish. await shutdown_task diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index c05faf399763..2bd0770fb1b9 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -39,12 +39,6 @@ async def handle_content_request(self, message: GroupChatRequestPublishEvent, ct # Pass the messages in the buffer to the delegate agent. response: Response | None = None async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): - if not any(isinstance(msg, msg_type) for msg_type in self._agent.produced_message_types): - raise ValueError( - f"The agent {self._agent.name} produced an unexpected message type: {type(msg)}. " - f"Expected one of: {self._agent.produced_message_types}. " - f"Check the agent's produced_message_types property." - ) if isinstance(msg, Response): response = msg else: diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 9dee76539be4..8fca0b42f557 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -6,9 +6,9 @@ import pytest from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent, Handoff +from autogen_agentchat.base import TaskResult from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages -from autogen_core.base import CancellationToken from autogen_core.components.tools import FunctionTool from autogen_ext.models import OpenAIChatCompletionClient from openai.resources.chat.completions import AsyncCompletions @@ -117,6 +117,16 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert isinstance(result.messages[2], ToolCallResultMessages) assert isinstance(result.messages[3], TextMessage) + # Test streaming. + mock._curr_index = 0 # pyright: ignore + index = 0 + async for message in tool_use_agent.run_stream("task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: @@ -160,8 +170,19 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: handoffs=[handoff], ) assert HandoffMessage in tool_use_agent.produced_message_types - response = await tool_use_agent.on_messages( - [TextMessage(content="task", source="user")], cancellation_token=CancellationToken() - ) - assert isinstance(response.chat_message, HandoffMessage) - assert response.chat_message.target == "agent2" + result = await tool_use_agent.run("task") + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], ToolCallMessage) + assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[3], HandoffMessage) + + # Test streaming. + mock._curr_index = 0 # pyright: ignore + index = 0 + async for message in tool_use_agent.run_stream("task"): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index e6510c2fa17e..3367d8874fbc 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -12,7 +12,7 @@ CodeExecutorAgent, Handoff, ) -from autogen_agentchat.base import Response +from autogen_agentchat.base import Response, TaskResult from autogen_agentchat.logging import FileLogHandler from autogen_agentchat.messages import ( ChatMessage, @@ -59,6 +59,9 @@ async def mock_create( self._curr_index += 1 return completion + def reset(self) -> None: + self._curr_index = 0 + class _EchoAgent(BaseChatAgent): def __init__(self, name: str, description: str) -> None: @@ -147,7 +150,8 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: ) team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent]) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) expected_messages = [ "Write a program that prints 'Hello, world!'", @@ -164,6 +168,18 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: # Assert that all expected messages are in the collected messages assert normalized_messages == expected_messages + # Test streaming. + mock.reset() + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: @@ -230,7 +246,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch echo_agent = _EchoAgent("echo_agent", description="echo agent") team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent]) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) assert len(result.messages) == 6 @@ -253,6 +270,19 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch assert context[2].content[0].call_id == "1" assert context[3].content == "Hello" + # Test streaming. + tool_use_agent._model_context.clear() # pyright: ignore + mock.reset() + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: @@ -320,7 +350,8 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: model_client=OpenAIChatCompletionClient(model=model, api_key=""), ) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) assert len(result.messages) == 6 assert result.messages[0].content == "Write a program that prints 'Hello, world!'" @@ -330,6 +361,19 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[4].source == "agent2" assert result.messages[5].source == "agent1" + # Test streaming. + mock.reset() + agent1._count = 0 # pyright: ignore + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None: @@ -356,7 +400,8 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) model_client=OpenAIChatCompletionClient(model=model, api_key=""), ) result = await team.run( - "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + "Write a program that prints 'Hello, world!'", + termination_condition=StopMessageTermination(), ) assert len(result.messages) == 5 assert result.messages[0].content == "Write a program that prints 'Hello, world!'" @@ -367,6 +412,19 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) # only one chat completion was called assert mock._curr_index == 1 # pyright: ignore + # Test streaming. + mock.reset() + agent1._count = 0 # pyright: ignore + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None: @@ -422,6 +480,18 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte assert result.messages[2].source == "agent2" assert result.messages[3].source == "agent1" + # Test streaming. + mock.reset() + index = 0 + async for message in team.run_stream( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + class _HandOffAgent(BaseChatAgent): def __init__(self, name: str, description: str, next_agent: str) -> None: @@ -446,8 +516,8 @@ async def test_swarm_handoff() -> None: second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") - team = Swarm([second_agent, first_agent, third_agent]) - result = await team.run("task", termination_condition=MaxMessageTermination(6)) + team = Swarm([second_agent, first_agent, third_agent], termination_condition=MaxMessageTermination(6)) + result = await team.run("task") assert len(result.messages) == 6 assert result.messages[0].content == "task" assert result.messages[1].content == "Transferred to third_agent." @@ -456,6 +526,15 @@ async def test_swarm_handoff() -> None: assert result.messages[4].content == "Transferred to third_agent." assert result.messages[5].content == "Transferred to first_agent." + # Test streaming. + index = 0 + async for message in team.run_stream("task", termination_condition=MaxMessageTermination(6)): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + @pytest.mark.asyncio async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None: @@ -514,13 +593,13 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - mock = _MockChatCompletion(chat_completions) monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) - agnet1 = AssistantAgent( + agent1 = AssistantAgent( "agent1", model_client=OpenAIChatCompletionClient(model=model, api_key=""), handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")], ) agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1") - team = Swarm([agnet1, agent2]) + team = Swarm([agent1, agent2]) result = await team.run("task", termination_condition=StopMessageTermination()) assert len(result.messages) == 7 assert result.messages[0].content == "task" @@ -530,3 +609,14 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - assert result.messages[4].content == "Transferred to agent1." assert result.messages[5].content == "Hello" assert result.messages[6].content == "TERMINATE" + + # Test streaming. + agent1._model_context.clear() # pyright: ignore + mock.reset() + index = 0 + async for message in team.run_stream("task", termination_condition=StopMessageTermination()): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 From 76cd615b8eae28cc946714251fa6818637459197 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 31 Oct 2024 13:33:19 -0700 Subject: [PATCH 5/6] Ready for review. --- .../agents/_assistant_agent.py | 37 ++++++++++---- .../agents/_base_chat_agent.py | 5 +- .../src/autogen_agentchat/base/_chat_agent.py | 4 +- .../src/autogen_agentchat/messages.py | 6 +-- .../_group_chat/_chat_agent_container.py | 4 ++ .../_group_chat/_round_robin_group_chat.py | 41 +++++++++++---- .../teams/_group_chat/_selector_group_chat.py | 50 ++++++++++++++++--- .../teams/_group_chat/_swarm_group_chat.py | 5 +- .../tests/test_assistant_agent.py | 6 +-- .../tests/test_group_chat.py | 12 +++-- 10 files changed, 125 insertions(+), 45 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 5414f782f022..06ed9195a0f2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from typing import Any, Awaitable, Callable, Dict, List, Sequence +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall @@ -27,7 +27,7 @@ StopMessage, TextMessage, ToolCallMessage, - ToolCallResultMessages, + ToolCallResultMessage, ) from ._base_chat_agent import BaseChatAgent @@ -138,7 +138,7 @@ class AssistantAgent(BaseChatAgent): The following example demonstrates how to create an assistant agent with - a model client and a tool, and generate a response to a simple task using the tool. + a model client and a tool, and generate a stream of messages for a task. .. code-block:: python @@ -154,7 +154,11 @@ async def get_current_time() -> str: model_client = OpenAIChatCompletionClient(model="gpt-4o") agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time]) - await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3)) + stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3)) + + async for message in stream: + print(message) + """ @@ -219,6 +223,14 @@ def produced_message_types(self) -> List[type[ChatMessage]]: return [TextMessage, StopMessage] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: + async for message in self.on_messages_stream(messages, cancellation_token): + if isinstance(message, Response): + return message + raise AssertionError("The stream should have returned the final result.") + + async def on_messages_stream( + self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken + ) -> AsyncGenerator[InnerMessage | Response, None]: # Add messages to the model context. for msg in messages: if isinstance(msg, ResetMessage): @@ -243,6 +255,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) # Add the tool call message to the output. inner_messages.append(ToolCallMessage(content=result.content, source=self.name)) + yield ToolCallMessage(content=result.content, source=self.name) # Execute the tool calls. results = await asyncio.gather( @@ -250,7 +263,8 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: ) event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name)) self._model_context.append(FunctionExecutionResultMessage(content=results)) - inner_messages.append(ToolCallResultMessages(content=results, source=self.name)) + inner_messages.append(ToolCallResultMessage(content=results, source=self.name)) + yield ToolCallResultMessage(content=results, source=self.name) # Detect handoff requests. handoffs: List[Handoff] = [] @@ -261,12 +275,13 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: if len(handoffs) > 1: raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") # Return the output messages to signal the handoff. - return Response( + yield Response( chat_message=HandoffMessage( content=handoffs[0].message, target=handoffs[0].target, source=self.name ), inner_messages=inner_messages, ) + return # Generate an inference result based on the current model context. result = await self._model_client.create( @@ -278,13 +293,13 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: # Detect stop request. request_stop = "terminate" in result.content.strip().lower() if request_stop: - return Response( + yield Response( chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages ) - - return Response( - chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages - ) + else: + yield Response( + chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages + ) async def _execute_tool_call( self, tool_call: FunctionCall, cancellation_token: CancellationToken diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 99f6f1a7e71a..cf146b0c10fb 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -42,14 +42,13 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken - ) -> AsyncGenerator[InnerMessage | ChatMessage | Response, None]: + ) -> AsyncGenerator[InnerMessage | Response, None]: """Handles incoming messages and returns a stream of messages and and the final item is the response. The base implementation in :class:`BaseChatAgent` simply calls :meth:`on_messages` and yields the messages in the response.""" response = await self.on_messages(messages, cancellation_token) for inner_message in response.inner_messages or []: yield inner_message - yield response.chat_message yield response async def run( @@ -84,6 +83,8 @@ async def run_stream( messages: List[InnerMessage | ChatMessage] = [first_message] async for message in self.on_messages_stream([first_message], cancellation_token): if isinstance(message, Response): + yield message.chat_message + messages.append(message.chat_message) yield TaskResult(messages=messages) else: messages.append(message) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index b6c76fa1f9be..ce73352daecc 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -46,7 +46,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken - ) -> AsyncGenerator[InnerMessage | ChatMessage | Response, None]: - """Handles incoming messages and returns a stream of messages and + ) -> AsyncGenerator[InnerMessage | Response, None]: + """Handles incoming messages and returns a stream of inner messages and and the final item is the response.""" ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index f206250e101e..51dbcca333d7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage): """The tool calls.""" -class ToolCallResultMessages(BaseMessage): +class ToolCallResultMessage(BaseMessage): """A message signaling the results of tool calls.""" content: List[FunctionExecutionResult] """The tool call results.""" -InnerMessage = ToolCallMessage | ToolCallResultMessages +InnerMessage = ToolCallMessage | ToolCallResultMessage """Messages for intra-agent monologues.""" @@ -80,6 +80,6 @@ class ToolCallResultMessages(BaseMessage): "HandoffMessage", "ResetMessage", "ToolCallMessage", - "ToolCallResultMessages", + "ToolCallResultMessage", "ChatMessage", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 2bd0770fb1b9..3fde3f6864b9 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -40,6 +40,10 @@ async def handle_content_request(self, message: GroupChatRequestPublishEvent, ct response: Response | None = None async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): if isinstance(msg, Response): + await self.publish_message( + msg.chat_message, + topic_id=DefaultTopicId(type=self._output_topic_type), + ) response = msg else: # Publish the message to the output topic. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index e8f5f66533f2..cec47f6e1b1b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -61,24 +61,45 @@ class RoundRobinGroupChat(BaseGroupChat): .. code-block:: python - from autogen_agentchat.agents import ToolUseAssistantAgent - from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_agentchat.task import StopMessageTermination - assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...) + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + + async def get_weather(location: str) -> str: + return f"The weather in {location} is sunny." + + + assistant = AssistantAgent( + "Assistant", + model_client=model_client, + tools=[get_weather], + ) team = RoundRobinGroupChat([assistant]) - await team.run("What's the weather in New York?", termination_condition=StopMessageTermination()) + stream = team.run_stream("What's the weather in New York?", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) A team with multiple participants: .. code-block:: python - from autogen_agentchat.agents import CodingAssistantAgent, CodeExecutorAgent - from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_agentchat.task import StopMessageTermination + + model_client = OpenAIChatCompletionClient(model="gpt-4o") - coding_assistant = CodingAssistantAgent("Coding_Assistant", model_client=...) - executor_agent = CodeExecutorAgent("Code_Executor", code_executor=...) - team = RoundRobinGroupChat([coding_assistant, executor_agent]) - await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()) + agent1 = AssistantAgent("Assistant1", model_client=model_client) + agent2 = AssistantAgent("Assistant2", model_client=model_client) + team = RoundRobinGroupChat([agent1, agent2]) + stream = team.run_stream("Tell me some jokes.", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) """ diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 3cc489daa6b7..ed7694d38584 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -170,14 +170,48 @@ class SelectorGroupChat(BaseGroupChat): .. code-block:: python - from autogen_agentchat.agents import ToolUseAssistantAgent - from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination - - travel_advisor = ToolUseAssistantAgent("Travel_Advisor", model_client=..., registered_tools=...) - hotel_agent = ToolUseAssistantAgent("Hotel_Agent", model_client=..., registered_tools=...) - flight_agent = ToolUseAssistantAgent("Flight_Agent", model_client=..., registered_tools=...) - team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=...) - await team.run("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import SelectorGroupChat + from autogen_agentchat.task import StopMessageTermination + + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + + async def lookup_hotel(location: str) -> str: + return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." + + + async def lookup_flight(origin: str, destination: str) -> str: + return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." + + + async def book_trip() -> str: + return "Your trip is booked!" + + + travel_advisor = AssistantAgent( + "Travel_Advisor", + model_client, + tools=[book_trip], + description="Helps with travel planning.", + ) + hotel_agent = AssistantAgent( + "Hotel_Agent", + model_client, + tools=[lookup_hotel], + description="Helps with hotel booking.", + ) + flight_agent = AssistantAgent( + "Flight_Agent", + model_client, + tools=[lookup_flight], + description="Helps with flight booking.", + ) + team = SelectorGroupChat([travel_advisor, hotel_agent, flight_agent], model_client=model_client) + stream = team.run_stream("Book a 3-day trip to new york.", termination_condition=StopMessageTermination()) + async for message in stream: + print(message) """ def __init__( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 872f12e2ba31..0f4ec0e63a48 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -79,7 +79,10 @@ class Swarm(BaseGroupChat): ) team = Swarm([agent1, agent2]) - await team.run("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) + + stream = team.run_stream("What is bob's birthday?", termination_condition=MaxMessageTermination(3)) + async for message in stream: + print(message) """ def __init__(self, participants: List[ChatAgent], termination_condition: TerminationCondition | None = None): diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 8fca0b42f557..4589f86860d3 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -8,7 +8,7 @@ from autogen_agentchat.agents import AssistantAgent, Handoff from autogen_agentchat.base import TaskResult from autogen_agentchat.logging import FileLogHandler -from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages +from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessage from autogen_core.components.tools import FunctionTool from autogen_ext.models import OpenAIChatCompletionClient from openai.resources.chat.completions import AsyncCompletions @@ -114,7 +114,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[1], ToolCallMessage) - assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[2], ToolCallResultMessage) assert isinstance(result.messages[3], TextMessage) # Test streaming. @@ -174,7 +174,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[1], ToolCallMessage) - assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[2], ToolCallResultMessage) assert isinstance(result.messages[3], HandoffMessage) # Test streaming. diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 3367d8874fbc..4e1485ce3094 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -20,7 +20,7 @@ StopMessage, TextMessage, ToolCallMessage, - ToolCallResultMessages, + ToolCallResultMessage, ) from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination from autogen_agentchat.teams import ( @@ -253,7 +253,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch assert len(result.messages) == 6 assert isinstance(result.messages[0], TextMessage) # task assert isinstance(result.messages[1], ToolCallMessage) # tool call - assert isinstance(result.messages[2], ToolCallResultMessages) # tool call result + assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result assert isinstance(result.messages[3], TextMessage) # tool use agent response assert isinstance(result.messages[4], TextMessage) # echo agent response assert isinstance(result.messages[5], StopMessage) # tool use agent response @@ -528,7 +528,8 @@ async def test_swarm_handoff() -> None: # Test streaming. index = 0 - async for message in team.run_stream("task", termination_condition=MaxMessageTermination(6)): + stream = team.run_stream("task", termination_condition=MaxMessageTermination(6)) + async for message in stream: if isinstance(message, TaskResult): assert message == result else: @@ -604,7 +605,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - assert len(result.messages) == 7 assert result.messages[0].content == "task" assert isinstance(result.messages[1], ToolCallMessage) - assert isinstance(result.messages[2], ToolCallResultMessages) + assert isinstance(result.messages[2], ToolCallResultMessage) assert result.messages[3].content == "handoff to agent2" assert result.messages[4].content == "Transferred to agent1." assert result.messages[5].content == "Hello" @@ -614,7 +615,8 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - agent1._model_context.clear() # pyright: ignore mock.reset() index = 0 - async for message in team.run_stream("task", termination_condition=StopMessageTermination()): + stream = team.run_stream("task", termination_condition=StopMessageTermination()) + async for message in stream: if isinstance(message, TaskResult): assert message == result else: From 5847417e3b5a2b04f6ded071fc673f62122dae5f Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 31 Oct 2024 14:02:25 -0700 Subject: [PATCH 6/6] fix bug in handoff tool --- .../src/autogen_agentchat/agents/_assistant_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 06ed9195a0f2..86a4f39952b8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -98,7 +98,11 @@ def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]: @property def handoff_tool(self) -> Tool: """Create a handoff tool from this handoff configuration.""" - return FunctionTool(lambda: self.message, name=self.name, description=self.description) + + def _handoff_tool() -> str: + return self.message + + return FunctionTool(_handoff_tool, name=self.name, description=self.description) class AssistantAgent(BaseChatAgent):