From 8139f7b0df0345d738a4176e88a9c28d88d52244 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Wed, 4 Dec 2024 02:26:25 +0530 Subject: [PATCH 01/19] feat: add support for list of messages as team task input --- .../teams/_group_chat/_base_group_chat.py | 19 ++++++-- .../tests/test_group_chat.py | 47 +++++++++++++++++++ 2 files changed, 61 insertions(+), 5 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index fdb79b1197f3..58f87e5a5f8c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -164,7 +164,7 @@ async def collect_output_messages( async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | list[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the team and return the result. The base implementation uses @@ -172,7 +172,7 @@ async def run( Once the team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | None): The task to run the team with. + task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. @@ -263,7 +263,7 @@ async def main() -> None: async def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | list[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result @@ -271,7 +271,7 @@ async def run_stream( team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | None): The task to run the team with. + task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. @@ -360,8 +360,17 @@ async def main() -> None: pass elif isinstance(task, str): first_chat_message = TextMessage(content=task, source="user") - elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): + elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): first_chat_message = task + elif isinstance(task, list): + if not task: + raise ValueError("Task list cannot be empty") + if not all(isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)) for msg in task): + raise ValueError("All messages in task list must be valid ChatMessage types") + first_chat_message = task[0] + # Queue remaining messages for processing + for msg in task[1:]: + await self._output_message_queue.put(msg) else: raise ValueError(f"Invalid task type: {type(task)}") diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 5c1d681fca07..b5eb7ff3f69d 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -942,3 +942,50 @@ async def test_swarm_with_handoff_termination() -> None: assert result.messages[1].content == "Transferred to second_agent." assert result.messages[2].content == "Transferred to third_agent." assert result.messages[3].content == "Transferred to non_existing_agent." + + +@pytest.mark.asyncio +async def test_round_robin_group_chat_with_message_list() -> None: + # Create a simple team with echo agents + agent1 = _EchoAgent("Agent1", "First agent") + agent2 = _EchoAgent("Agent2", "Second agent") + termination = MaxMessageTermination(4) # Stop after 4 messages + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + # Create a list of messages + messages = [ + TextMessage(content="Message 1", source="user"), + TextMessage(content="Message 2", source="user"), + TextMessage(content="Message 3", source="user"), + ] + + # Run the team with the message list + result = await team.run(task=messages) + + # Verify the messages were processed in order + assert len(result.messages) == 6 # All initial messages + echoes until termination + assert result.messages[0].content == "Message 2" # Second message from queue + assert result.messages[1].content == "Message 3" # Third message from queue + assert result.messages[2].content == "Message 1" # First message (processed first) + assert result.messages[3].content == "Message 1" # Echo from first agent + assert result.messages[4].content == "Message 1" # Echo from second agent + assert result.messages[5].content == "Message 1" # Echo from first agent (before termination) + assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4" + + # Test with streaming + await team.reset() + index = 0 + async for message in team.run_stream(task=messages): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + + # Test with invalid message list + with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"): + await team.run(task=["not a message"]) + + # Test with empty message list + with pytest.raises(ValueError, match="Task list cannot be empty"): + await team.run(task=[]) From 424da3d1ab09e5bb8f7936148e3d2fe4e026b3b5 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Wed, 4 Dec 2024 16:03:19 +0530 Subject: [PATCH 02/19] feat: enhance task handling to support single and multiple messages in group chat --- .../src/autogen_agentchat/base/_task.py | 6 ++-- .../teams/_group_chat/_base_group_chat.py | 14 +++++---- .../_group_chat/_base_group_chat_manager.py | 30 ++++++++++++------- .../teams/_group_chat/_events.py | 6 ++-- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index f617b3823451..d0d9aefe70a7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import AsyncGenerator, Protocol, Sequence +from typing import AsyncGenerator, List, Protocol, Sequence from autogen_core import CancellationToken @@ -23,7 +23,7 @@ class TaskRunner(Protocol): async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the task and return the result. @@ -36,7 +36,7 @@ async def run( def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the task and produces a stream of messages and the final result diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 58f87e5a5f8c..8a9080caa163 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -149,7 +149,11 @@ async def collect_output_messages( if isinstance(message, GroupChatTermination): self._stop_reason = message.message.content return - await self._output_message_queue.put(message.message) + + # Handle single message or list of messages + messages = message.message if isinstance(message.message, List) else [message.message] + for msg in messages: + await self._output_message_queue.put(msg) await ClosureAgent.register_closure( runtime, @@ -164,7 +168,7 @@ async def collect_output_messages( async def run( self, *, - task: str | ChatMessage | list[ChatMessage] | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the team and return the result. The base implementation uses @@ -172,7 +176,7 @@ async def run( Once the team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages. + task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. @@ -263,7 +267,7 @@ async def main() -> None: async def run_stream( self, *, - task: str | ChatMessage | list[ChatMessage] | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result @@ -271,7 +275,7 @@ async def run_stream( team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages. + task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index aefe4f8d49d8..4b2294da44a9 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -74,20 +74,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self.validate_group_state(message.message) if message.message is not None: - # Log the start message. - await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) + messages_to_process = [message.message] if isinstance(message.message, ChatMessage) else message.message - # Relay the start message to the participants. - await self.publish_message( - message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token - ) + # Log and relay each message + for msg in messages_to_process: + # Log the message + await self.publish_message( + GroupChatStart(message=msg), topic_id=DefaultTopicId(type=self._output_topic_type) + ) + + # Relay the message to participants + await self.publish_message( + GroupChatStart(message=msg), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=ctx.cancellation_token, + ) - # Append the user message to the message thread. - self._message_thread.append(message.message) + # Append to message thread + self._message_thread.append(msg) - # Check if the conversation should be terminated. + # Check termination condition after processing all messages if self._termination_condition is not None: - stop_message = await self._termination_condition([message.message]) + stop_message = await self._termination_condition(messages_to_process) if stop_message is not None: await self.publish_message( GroupChatTermination(message=stop_message), @@ -97,7 +105,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self._termination_condition.reset() return - # Select a speaker to start the conversation. + # Select a speaker to start/continue the conversation speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) # Link the select speaker future to the cancellation token. ctx.cancellation_token.link_future(speaker_topic_type_future) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index 4ae4d892cace..fa0f556547b2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -1,3 +1,5 @@ +from typing import List + from pydantic import BaseModel from ...base import Response @@ -7,8 +9,8 @@ class GroupChatStart(BaseModel): """A request to start a group chat.""" - message: ChatMessage | None = None - """An optional user message to start the group chat.""" + message: ChatMessage | List[ChatMessage] | None = None + """An optional user message or list of messages to start the group chat.""" class GroupChatAgentResponse(BaseModel): From 9a7e0a64724c9efe14737cb8dd1c459cb23b95ae Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Thu, 5 Dec 2024 10:54:17 +0530 Subject: [PATCH 03/19] Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../teams/_group_chat/_base_group_chat_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 4b2294da44a9..507f812cd5ac 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -95,7 +95,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No # Check termination condition after processing all messages if self._termination_condition is not None: - stop_message = await self._termination_condition(messages_to_process) + stop_message = await self._termination_condition([msg]) if stop_message is not None: await self.publish_message( GroupChatTermination(message=stop_message), From 71f3db58cf43041bd499071baa25457e8db91175 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Sat, 7 Dec 2024 18:21:24 +0530 Subject: [PATCH 04/19] Refactor message processing in handle_start to check for ChatMessage type using 'type' attribute --- .../teams/_group_chat/_base_group_chat_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 507f812cd5ac..e674dadb5e1a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -74,7 +74,8 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self.validate_group_state(message.message) if message.message is not None: - messages_to_process = [message.message] if isinstance(message.message, ChatMessage) else message.message + # Check if message is a ChatMessage by checking for the discriminator field 'type' + messages_to_process = [message.message] if hasattr(message.message, 'type') else message.message # Log and relay each message for msg in messages_to_process: @@ -95,7 +96,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No # Check termination condition after processing all messages if self._termination_condition is not None: - stop_message = await self._termination_condition([msg]) + stop_message = await self._termination_condition(messages_to_process) if stop_message is not None: await self.publish_message( GroupChatTermination(message=stop_message), From abaf4b9c5ff0e17ee9567bd529cd0d96390d5614 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Sat, 7 Dec 2024 20:40:22 +0530 Subject: [PATCH 05/19] update _base_group_chat_manager.py, _chat_agent_container.py and _magentic_one_orchestrator.py --- .../_group_chat/_base_group_chat_manager.py | 28 +++++++++---------- .../_group_chat/_chat_agent_container.py | 4 ++- .../_magentic_one_orchestrator.py | 5 +++- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index e674dadb5e1a..2f59df3dc76e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -75,24 +75,22 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No if message.message is not None: # Check if message is a ChatMessage by checking for the discriminator field 'type' - messages_to_process = [message.message] if hasattr(message.message, 'type') else message.message + messages_to_process = [message.message] if hasattr(message.message, "type") else message.message - # Log and relay each message - for msg in messages_to_process: - # Log the message - await self.publish_message( - GroupChatStart(message=msg), topic_id=DefaultTopicId(type=self._output_topic_type) - ) + # Log all messages at once + await self.publish_message( + GroupChatStart(message=messages_to_process), topic_id=DefaultTopicId(type=self._output_topic_type) + ) - # Relay the message to participants - await self.publish_message( - GroupChatStart(message=msg), - topic_id=DefaultTopicId(type=self._group_topic_type), - cancellation_token=ctx.cancellation_token, - ) + # Relay all messages at once to participants + await self.publish_message( + GroupChatStart(message=messages_to_process), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=ctx.cancellation_token, + ) - # Append to message thread - self._message_thread.append(msg) + # Append all messages to thread + self._message_thread.extend(messages_to_process) # Check termination condition after processing all messages if self._termination_condition is not None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index fdf5428b3b5c..806417eeb234 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -31,7 +31,9 @@ def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAg async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: """Handle a start event by appending the content to the buffer.""" if message.message is not None: - self._message_buffer.append(message.message) + # Handle single message or list of messages + messages = [message.message] if hasattr(message.message, "type") else message.message + self._message_buffer.extend(messages) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 11b143e28c0b..1e9fcdcb3401 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -127,7 +127,10 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No # Outer Loop for first time # Create the initial task ledger ################################# - self._task = self._content_to_str(message.message.content) + # Handle single message or list of messages + messages = [message.message] if hasattr(message.message, "type") else message.message + # Combine all message contents for task + self._task = " ".join([self._content_to_str(msg.content) for msg in messages]) planning_conversation: List[LLMMessage] = [] # 1. GATHER FACTS From c08b8cf25d20bf90ab7670974f58a0a7f92a9eec Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Sat, 7 Dec 2024 23:57:43 +0530 Subject: [PATCH 06/19] Refactor GroupChatStart to support multiple messages; update related handling in _base_group_chat_manager, _chat_agent_container, and _magentic_one_orchestrator. --- .../teams/_group_chat/_base_group_chat.py | 25 ++++++++++++------- .../_group_chat/_base_group_chat_manager.py | 17 ++++++------- .../_group_chat/_chat_agent_container.py | 6 ++--- .../teams/_group_chat/_events.py | 4 +-- .../_magentic_one_orchestrator.py | 8 +++--- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index d5169742e899..3856ea7fcf0a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -146,15 +146,18 @@ async def collect_output_messages( message: GroupChatStart | GroupChatMessage | GroupChatTermination, ctx: MessageContext, ) -> None: - event_logger.info(message.message) - if isinstance(message, GroupChatTermination): + """Collect output messages from the group chat.""" + if isinstance(message, GroupChatStart): + if message.messages is not None: + for msg in message.messages: + event_logger.info(msg) + await self._output_message_queue.put(msg) + elif isinstance(message, GroupChatMessage): + event_logger.info(message.message) + await self._output_message_queue.put(message.message) + elif isinstance(message, GroupChatTermination): + event_logger.info(message.message) self._stop_reason = message.message.content - return - - # Handle single message or list of messages - messages = message.message if isinstance(message.message, List) else [message.message] - for msg in messages: - await self._output_message_queue.put(msg) await ClosureAgent.register_closure( runtime, @@ -401,8 +404,12 @@ async def stop_runtime() -> None: # Run the team by sending the start message to the group chat manager. # The group chat manager will start the group chat by relaying the message to the participants # and the closure agent. + if first_chat_message is not None: + messages = [first_chat_message] + else: + messages = None await self._runtime.send_message( - GroupChatStart(message=first_chat_message), + GroupChatStart(messages=messages), recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), cancellation_token=cancellation_token, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 2f59df3dc76e..2726b43673ff 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -70,31 +70,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No # Stop the group chat. return - # Validate the group state given the start message. - await self.validate_group_state(message.message) - - if message.message is not None: - # Check if message is a ChatMessage by checking for the discriminator field 'type' - messages_to_process = [message.message] if hasattr(message.message, "type") else message.message + # Validate the group state given the start messages + await self.validate_group_state(message.messages[0] if message.messages else None) + if message.messages is not None: # Log all messages at once await self.publish_message( - GroupChatStart(message=messages_to_process), topic_id=DefaultTopicId(type=self._output_topic_type) + GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type) ) # Relay all messages at once to participants await self.publish_message( - GroupChatStart(message=messages_to_process), + GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token, ) # Append all messages to thread - self._message_thread.extend(messages_to_process) + self._message_thread.extend(message.messages) # Check termination condition after processing all messages if self._termination_condition is not None: - stop_message = await self._termination_condition(messages_to_process) + stop_message = await self._termination_condition(message.messages) if stop_message is not None: await self.publish_message( GroupChatTermination(message=stop_message), diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 806417eeb234..f01465d4c3d5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -30,10 +30,8 @@ def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAg @event async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: """Handle a start event by appending the content to the buffer.""" - if message.message is not None: - # Handle single message or list of messages - messages = [message.message] if hasattr(message.message, "type") else message.message - self._message_buffer.extend(messages) + if message.messages is not None: + self._message_buffer.extend(message.messages) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index fa0f556547b2..ed325fcb5159 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -9,8 +9,8 @@ class GroupChatStart(BaseModel): """A request to start a group chat.""" - message: ChatMessage | List[ChatMessage] | None = None - """An optional user message or list of messages to start the group chat.""" + messages: List[ChatMessage] | None = None + """An optional list of messages to start the group chat.""" class GroupChatAgentResponse(BaseModel): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 1e9fcdcb3401..18a9ca461e61 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -117,20 +117,18 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No ) # Stop the group chat. return - assert message is not None and message.message is not None + assert message is not None and message.messages is not None # Validate the group state given the start message. - await self.validate_group_state(message.message) + await self.validate_group_state(message.messages[0]) # Log the start message. await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) # Outer Loop for first time # Create the initial task ledger ################################# - # Handle single message or list of messages - messages = [message.message] if hasattr(message.message, "type") else message.message # Combine all message contents for task - self._task = " ".join([self._content_to_str(msg.content) for msg in messages]) + self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages]) planning_conversation: List[LLMMessage] = [] # 1. GATHER FACTS From cfe3aafbb0cfe284f7eaa3c8251aa9a305bae8c3 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Sun, 8 Dec 2024 17:43:19 +0530 Subject: [PATCH 07/19] Enhance task handling in _base_chat_agent.py to support a list of ChatMessages --- .../autogen_agentchat/agents/_base_chat_agent.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 5b2aed4860c1..7b8a2aecb07a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -55,7 +55,7 @@ async def on_messages_stream( async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the agent with the given task and return the result.""" @@ -69,7 +69,7 @@ async def run( text_msg = TextMessage(content=task, source="user") input_messages.append(text_msg) output_messages.append(text_msg) - elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): + elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): input_messages.append(task) output_messages.append(task) else: @@ -83,7 +83,7 @@ async def run( async def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the agent with the given task and return a stream of messages @@ -99,7 +99,15 @@ async def run_stream( input_messages.append(text_msg) output_messages.append(text_msg) yield text_msg - elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): + elif isinstance(task, list): + for msg in task: + if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + input_messages.append(msg) + output_messages.append(msg) + yield msg + else: + raise ValueError(f"Invalid message type in list: {type(msg)}") + elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): input_messages.append(task) output_messages.append(task) yield task From ac011b8053641375dc5f5cffa732237e728b62ee Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Sun, 8 Dec 2024 18:17:39 +0530 Subject: [PATCH 08/19] Refactor validate_group_state method to accept a list of ChatMessages in _base_group_chat_manager and _swarm_group_chat --- .../_group_chat/_base_group_chat_manager.py | 11 ++++++--- .../teams/_group_chat/_swarm_group_chat.py | 23 +++++++++++-------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 2726b43673ff..84725cecdd65 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -71,7 +71,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No return # Validate the group state given the start messages - await self.validate_group_state(message.messages[0] if message.messages else None) + await self.validate_group_state(message.messages) if message.messages is not None: # Log all messages at once @@ -170,8 +170,13 @@ async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> No await self.reset() @abstractmethod - async def validate_group_state(self, message: ChatMessage | None) -> None: - """Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event.""" + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: + """Validate the state of the group chat given the start messages. + This is executed when the group chat manager receives a GroupChatStart event. + + Args: + messages: A list of chat messages to validate, or None if no messages are provided. + """ ... @abstractmethod diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 10574e0a9fa6..ae58b3b07934 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -29,16 +29,19 @@ def __init__( ) self._current_speaker = participant_topic_types[0] - async def validate_group_state(self, message: ChatMessage | None) -> None: - """Validate the start message for the group chat.""" - # Check if the start message is a handoff message. - if isinstance(message, HandoffMessage): - if message.target not in self._participant_topic_types: - raise ValueError( - f"The target {message.target} is not one of the participants {self._participant_topic_types}. " - "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target." - ) - return + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: + """Validate the start messages for the group chat.""" + # Check if any of the start messages is a handoff message. + if messages: + for message in messages: + if isinstance(message, HandoffMessage): + if message.target not in self._participant_topic_types: + raise ValueError( + f"The target {message.target} is not one of the participants {self._participant_topic_types}. " + "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target." + ) + return + # Check if there is a handoff message in the thread that is not targeting a valid participant. for existing_message in reversed(self._message_thread): if isinstance(existing_message, HandoffMessage): From 376f337082aa88f1976e2e7607cb5ab58ad15d8a Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Mon, 9 Dec 2024 13:04:49 +0530 Subject: [PATCH 09/19] feat: add list input support for chat messages with unit tests --- .../agents/_base_chat_agent.py | 21 ++- .../_magentic_one_orchestrator.py | 1 - .../tests/test_assistant_agent.py | 128 ++++++++++++++++-- 3 files changed, 131 insertions(+), 19 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 7b8a2aecb07a..b4e58546c05a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -4,7 +4,14 @@ from autogen_core import CancellationToken from ..base import ChatAgent, Response, TaskResult -from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage +from ..messages import ( + AgentMessage, + ChatMessage, + HandoffMessage, + MultiModalMessage, + StopMessage, + TextMessage, +) from ..state import BaseState @@ -45,8 +52,9 @@ async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: """Handles incoming messages and returns a stream of messages and - and the final item is the response. The base implementation in :class:`BaseChatAgent` - simply calls :meth:`on_messages` and yields the messages in the response.""" + and the final item is the response. The base implementation in + :class:`BaseChatAgent` simply calls :meth:`on_messages` and yields + the messages in the response.""" response = await self.on_messages(messages, cancellation_token) for inner_message in response.inner_messages or []: yield inner_message @@ -69,6 +77,13 @@ async def run( text_msg = TextMessage(content=task, source="user") input_messages.append(text_msg) output_messages.append(text_msg) + elif isinstance(task, list): + for msg in task: + if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + input_messages.append(msg) + output_messages.append(msg) + else: + raise ValueError(f"Invalid message type in list: {type(msg)}") elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): input_messages.append(task) output_messages.append(task) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 249375a32869..3d6431c2d813 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -22,7 +22,6 @@ ToolCallMessage, ToolCallResultMessage, ) - from ....state import MagenticOneOrchestratorState from .._base_group_chat_manager import BaseGroupChatManager from .._events import ( diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index c09a92a28d3d..f0ecb89a92f5 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -8,6 +8,7 @@ from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult from autogen_agentchat.messages import ( + ChatMessage, HandoffMessage, MultiModalMessage, TextMessage, @@ -21,7 +22,10 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_message import ChatCompletionMessage -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) from openai.types.completion_usage import CompletionUsage from utils import FileLogHandler @@ -33,14 +37,14 @@ class _MockChatCompletion: def __init__(self, chat_completions: List[ChatCompletion]) -> None: self._saved_chat_completions = chat_completions - self._curr_index = 0 + self.curr_index = 0 async def mock_create( self, *args: Any, **kwargs: Any ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]: await asyncio.sleep(0.1) - completion = self._saved_chat_completions[self._curr_index] - self._curr_index += 1 + completion = self._saved_chat_completions[self.curr_index] + self.curr_index += 1 return completion @@ -90,7 +94,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: ChatCompletion( id="id2", choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Hello", role="assistant"), + ) ], created=0, model=model, @@ -101,7 +109,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: id="id2", choices=[ Choice( - finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant") + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="TERMINATE", role="assistant"), ) ], created=0, @@ -115,7 +125,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: agent = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], ) result = await agent.run(task="task") assert len(result.messages) == 4 @@ -133,7 +147,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[3].models_usage.prompt_tokens == 10 # Test streaming. - mock._curr_index = 0 # pyright: ignore + mock.curr_index = 0 # pyright: ignore index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): @@ -147,7 +161,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: agent2 = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], ) await agent2.load_state(state) state2 = await agent2.save_state() @@ -192,7 +210,11 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: tool_use_agent = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], handoffs=[handoff], ) assert HandoffMessage in tool_use_agent.produced_message_types @@ -212,7 +234,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[3].models_usage is None # Test streaming. - mock._curr_index = 0 # pyright: ignore + mock.curr_index = 0 # pyright: ignore index = 0 async for message in tool_use_agent.run_stream(task="task"): if isinstance(message, TaskResult): @@ -229,7 +251,11 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: ChatCompletion( id="id2", choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Hello", role="assistant"), + ) ], created=0, model=model, @@ -239,7 +265,10 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: ] mock = _MockChatCompletion(chat_completions) monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) - agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")) + agent = AssistantAgent( + name="assistant", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + ) # Generate a random base64 image. img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) @@ -250,14 +279,24 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: async def test_invalid_model_capabilities() -> None: model = "random-model" model_client = OpenAIChatCompletionClient( - model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False} + model=model, + api_key="", + model_capabilities={ + "vision": False, + "function_calling": False, + "json_output": False, + }, ) with pytest.raises(ValueError): agent = AssistantAgent( name="assistant", model_client=model_client, - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], ) with pytest.raises(ValueError): @@ -268,3 +307,62 @@ async def test_invalid_model_capabilities() -> None: # Generate a random base64 image. img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) + + +@pytest.mark.asyncio +async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Response to message 1", role="assistant"), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + agent = AssistantAgent( + "test_agent", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + ) + + # Create a list of chat messages + messages: List[ChatMessage] = [ + TextMessage(content="Message 1", source="user"), + TextMessage(content="Message 2", source="user"), + ] + + # Test run method with list of messages + result = await agent.run(task=messages) + assert len(result.messages) == 3 # 2 input messages + 1 response message + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].content == "Message 1" + assert result.messages[0].source == "user" + assert isinstance(result.messages[1], TextMessage) + assert result.messages[1].content == "Message 2" + assert result.messages[1].source == "user" + assert isinstance(result.messages[2], TextMessage) + assert result.messages[2].content == "Response to message 1" + assert result.messages[2].source == "test_agent" + assert result.messages[2].models_usage is not None + assert result.messages[2].models_usage.completion_tokens == 5 + assert result.messages[2].models_usage.prompt_tokens == 10 + + # Test run_stream method with list of messages + mock.curr_index = 0 # Reset mock index using public attribute + index = 0 + async for message in agent.run_stream(task=messages): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 From 966674045817dd95936c3998d31104a001e5ca18 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Tue, 10 Dec 2024 21:28:06 +0530 Subject: [PATCH 10/19] fix: type check for Annotated types --- .../autogen_agentchat/agents/_base_chat_agent.py | 15 ++++++++++----- .../teams/_group_chat/_base_group_chat.py | 14 +++++++++++--- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index b4e58546c05a..b17bf18c309e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,5 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, List, Mapping, Sequence +from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args, get_origin + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated from autogen_core import CancellationToken @@ -79,12 +84,12 @@ async def run( output_messages.append(text_msg) elif isinstance(task, list): for msg in task: - if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(msg) output_messages.append(msg) else: raise ValueError(f"Invalid message type in list: {type(msg)}") - elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(task) output_messages.append(task) else: @@ -116,13 +121,13 @@ async def run_stream( yield text_msg elif isinstance(task, list): for msg in task: - if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(msg) output_messages.append(msg) yield msg else: raise ValueError(f"Invalid message type in list: {type(msg)}") - elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(task) output_messages.append(task) yield task diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 3856ea7fcf0a..2de743c6297d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -2,7 +2,12 @@ import logging import uuid from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, List, Mapping +from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args, get_origin + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated from autogen_core import ( AgentId, @@ -368,12 +373,15 @@ async def main() -> None: pass elif isinstance(task, str): first_chat_message = TextMessage(content=task, source="user") - elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: first_chat_message = task elif isinstance(task, list): if not task: raise ValueError("Task list cannot be empty") - if not all(isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)) for msg in task): + if not all( + get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__ + for msg in task + ): raise ValueError("All messages in task list must be valid ChatMessage types") first_chat_message = task[0] # Queue remaining messages for processing From 46381ec691d3c6fe579ca477f819db5b261b2ef2 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Tue, 10 Dec 2024 22:24:24 +0530 Subject: [PATCH 11/19] fix(test): Update mock chat completion message content to match expected result in test_run_with_tools --- .../autogen-agentchat/tests/test_assistant_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 5d0626070d3f..c132e3a4862c 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -97,7 +97,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="Hello", role="assistant"), + message=ChatCompletionMessage(content="pass", role="assistant"), ) ], created=0, @@ -147,14 +147,14 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[3].models_usage is None # Test streaming. - mock._curr_index = 0 # pyright: ignore + mock.curr_index = 0 # Reset the mock index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): assert message == result else: assert message == result.messages[index] - index += 1 + index += 1 # Test state saving and loading. state = await agent.save_state() From cd7a24537ff903552a974b88db22c37702b60c98 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Tue, 10 Dec 2024 23:18:45 +0530 Subject: [PATCH 12/19] fix: ChatMessage to List[ChatMessage] --- .../teams/_group_chat/_base_group_chat.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 5f149414524a..c00387eab732 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -24,7 +24,7 @@ from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TaskResult, Team, TerminationCondition -from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage +from ...messages import AgentMessage, ChatMessage, TextMessage from ...state import TeamState from ._chat_agent_container import ChatAgentContainer from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination @@ -368,7 +368,7 @@ async def main() -> None: """ # Create the first chat message if the task is a string or a chat message. - first_chat_message: ChatMessage | None = None + first_chat_message: ChatMessage | List[ChatMessage] | None = None if task is None: pass elif isinstance(task, str): @@ -378,11 +378,10 @@ async def main() -> None: elif isinstance(task, list): if not task: raise ValueError("Task list cannot be empty") - if not all( - get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__ - for msg in task - ): - raise ValueError("All messages in task list must be valid ChatMessage types") + # Validate all messages in the list + for msg in task: + if not (isinstance(msg, TextMessage) or isinstance(msg, get_args(ChatMessage)[0].__args__)): + raise ValueError("All messages in task list must be valid ChatMessage types") first_chat_message = task[0] # Queue remaining messages for processing for msg in task[1:]: @@ -413,7 +412,10 @@ async def stop_runtime() -> None: # The group chat manager will start the group chat by relaying the message to the participants # and the closure agent. if first_chat_message is not None: - messages = [first_chat_message] + if isinstance(first_chat_message, list): + messages = first_chat_message + else: + messages = [first_chat_message] else: messages = None await self._runtime.send_message( From 0f875d1246095215e85e7f0739a4bd0f3df77e37 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Wed, 11 Dec 2024 01:32:15 +0530 Subject: [PATCH 13/19] fixed mypy and pyright issues related to type checks --- .../agents/_base_chat_agent.py | 18 +++++------------- .../teams/_group_chat/_base_group_chat.py | 18 +++++++----------- .../_magentic_one_orchestrator.py | 4 ++-- .../_group_chat/_round_robin_group_chat.py | 2 +- .../teams/_group_chat/_selector_group_chat.py | 2 +- .../autogen-agentchat/tests/test_group_chat.py | 4 ++-- 6 files changed, 18 insertions(+), 30 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index b17bf18c309e..c06fb8d6db53 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,10 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args, get_origin - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated +from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args from autogen_core import CancellationToken @@ -12,9 +7,6 @@ from ..messages import ( AgentMessage, ChatMessage, - HandoffMessage, - MultiModalMessage, - StopMessage, TextMessage, ) from ..state import BaseState @@ -84,12 +76,12 @@ async def run( output_messages.append(text_msg) elif isinstance(task, list): for msg in task: - if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__: + if isinstance(msg, get_args(ChatMessage)[0]): input_messages.append(msg) output_messages.append(msg) else: raise ValueError(f"Invalid message type in list: {type(msg)}") - elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: + elif isinstance(task, get_args(ChatMessage)[0]): input_messages.append(task) output_messages.append(task) else: @@ -121,13 +113,13 @@ async def run_stream( yield text_msg elif isinstance(task, list): for msg in task: - if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__: + if isinstance(msg, get_args(ChatMessage)[0]): input_messages.append(msg) output_messages.append(msg) yield msg else: raise ValueError(f"Invalid message type in list: {type(msg)}") - elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: + elif isinstance(task, get_args(ChatMessage)[0]): input_messages.append(task) output_messages.append(task) yield task diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index c00387eab732..15baf17d2db3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -2,12 +2,7 @@ import logging import uuid from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args, get_origin - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated +from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args from autogen_core import ( AgentId, @@ -373,15 +368,16 @@ async def main() -> None: pass elif isinstance(task, str): first_chat_message = TextMessage(content=task, source="user") - elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: + elif isinstance(task, get_args(ChatMessage)[0]): first_chat_message = task elif isinstance(task, list): if not task: raise ValueError("Task list cannot be empty") - # Validate all messages in the list - for msg in task: - if not (isinstance(msg, TextMessage) or isinstance(msg, get_args(ChatMessage)[0].__args__)): - raise ValueError("All messages in task list must be valid ChatMessage types") + if not all( + isinstance(msg, get_args(ChatMessage)[0]) + for msg in task + ): + raise ValueError("All messages in task list must be valid ChatMessage types") first_chat_message = task[0] # Queue remaining messages for processing for msg in task[1:]: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 389aa6fc0149..fbd9b3b0816e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -129,7 +129,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No assert message is not None and message.messages is not None # Validate the group state given the start message. - await self.validate_group_state(message.messages[0]) + await self.validate_group_state([message.messages[0]]) # Log the start message. await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) @@ -185,7 +185,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess return await self._orchestrate_step(ctx.cancellation_token) - async def validate_group_state(self, message: ChatMessage | None) -> None: + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass async def save_state(self) -> Mapping[str, Any]: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 8330c89a3f12..3e17943b90b2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -29,7 +29,7 @@ def __init__( ) self._next_speaker_index = 0 - async def validate_group_state(self, message: ChatMessage | None) -> None: + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass async def reset(self) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index a42b6fc2bfbb..5f161d0c6858 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -54,7 +54,7 @@ def __init__( self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func - async def validate_group_state(self, message: ChatMessage | None) -> None: + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass async def reset(self) -> None: diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 23e75937256e..a32dbb006498 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -1036,7 +1036,7 @@ async def test_round_robin_group_chat_with_message_list() -> None: team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) # Create a list of messages - messages = [ + messages: List[ChatMessage] = [ TextMessage(content="Message 1", source="user"), TextMessage(content="Message 2", source="user"), TextMessage(content="Message 3", source="user"), @@ -1067,7 +1067,7 @@ async def test_round_robin_group_chat_with_message_list() -> None: # Test with invalid message list with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"): - await team.run(task=["not a message"]) + await team.run(task=["not a message"]) # type: ignore[list-item] # intentionally testing invalid input # Test with empty message list with pytest.raises(ValueError, match="Task list cannot be empty"): From cbe21afca13d163a032d539d6184d50220186aae Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Wed, 11 Dec 2024 09:30:07 +0530 Subject: [PATCH 14/19] fix: formatting of _base_group_chat --- .../autogen_agentchat/teams/_group_chat/_base_group_chat.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 15baf17d2db3..f00865d77c28 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -373,10 +373,7 @@ async def main() -> None: elif isinstance(task, list): if not task: raise ValueError("Task list cannot be empty") - if not all( - isinstance(msg, get_args(ChatMessage)[0]) - for msg in task - ): + if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task): raise ValueError("All messages in task list must be valid ChatMessage types") first_chat_message = task[0] # Queue remaining messages for processing From 55feb5f8925962eaecf933ec3312e6a82b68385f Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Thu, 12 Dec 2024 18:16:09 +0530 Subject: [PATCH 15/19] refactor: improve message handling in base group chat --- .../teams/_group_chat/_base_group_chat.py | 24 +++++++------------ .../tests/test_group_chat.py | 12 ++++------ 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index f00865d77c28..21c0126a774f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -362,23 +362,20 @@ async def main() -> None: """ - # Create the first chat message if the task is a string or a chat message. - first_chat_message: ChatMessage | List[ChatMessage] | None = None + # Create the messages list if the task is a string or a chat message. + messages: List[ChatMessage] | None = None if task is None: pass elif isinstance(task, str): - first_chat_message = TextMessage(content=task, source="user") + messages = [TextMessage(content=task, source="user")] elif isinstance(task, get_args(ChatMessage)[0]): - first_chat_message = task + messages = [task] # type: ignore elif isinstance(task, list): if not task: raise ValueError("Task list cannot be empty") if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task): raise ValueError("All messages in task list must be valid ChatMessage types") - first_chat_message = task[0] - # Queue remaining messages for processing - for msg in task[1:]: - await self._output_message_queue.put(msg) + messages = task else: raise ValueError(f"Invalid task type: {type(task)}") @@ -404,13 +401,10 @@ async def stop_runtime() -> None: # Run the team by sending the start message to the group chat manager. # The group chat manager will start the group chat by relaying the message to the participants # and the closure agent. - if first_chat_message is not None: - if isinstance(first_chat_message, list): - messages = first_chat_message - else: - messages = [first_chat_message] - else: - messages = None + if messages is not None: + if not isinstance(messages, list): + messages = [messages] + await self._runtime.send_message( GroupChatStart(messages=messages), recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index a32dbb006498..6a3c91b805c3 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -1046,13 +1046,11 @@ async def test_round_robin_group_chat_with_message_list() -> None: result = await team.run(task=messages) # Verify the messages were processed in order - assert len(result.messages) == 6 # All initial messages + echoes until termination - assert result.messages[0].content == "Message 2" # Second message from queue - assert result.messages[1].content == "Message 3" # Third message from queue - assert result.messages[2].content == "Message 1" # First message (processed first) + assert len(result.messages) == 4 # Initial messages + echo until termination + assert result.messages[0].content == "Message 1" # First message + assert result.messages[1].content == "Message 2" # Second message + assert result.messages[2].content == "Message 3" # Third message assert result.messages[3].content == "Message 1" # Echo from first agent - assert result.messages[4].content == "Message 1" # Echo from second agent - assert result.messages[5].content == "Message 1" # Echo from first agent (before termination) assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4" # Test with streaming @@ -1067,7 +1065,7 @@ async def test_round_robin_group_chat_with_message_list() -> None: # Test with invalid message list with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"): - await team.run(task=["not a message"]) # type: ignore[list-item] # intentionally testing invalid input + await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input # Test with empty message list with pytest.raises(ValueError, match="Task list cannot be empty"): From b5c0d1af7081cf6aae07e8c4f30bc613467ead86 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Thu, 12 Dec 2024 18:59:20 +0530 Subject: [PATCH 16/19] refactor: streamline task validation and message handling in base group chat --- .../autogen_agentchat/teams/_group_chat/_base_group_chat.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 21c0126a774f..7a6496acbe87 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -376,8 +376,6 @@ async def main() -> None: if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task): raise ValueError("All messages in task list must be valid ChatMessage types") messages = task - else: - raise ValueError(f"Invalid task type: {type(task)}") if self._is_running: raise ValueError("The team is already running, it cannot run again until it is stopped.") @@ -401,10 +399,6 @@ async def stop_runtime() -> None: # Run the team by sending the start message to the group chat manager. # The group chat manager will start the group chat by relaying the message to the participants # and the closure agent. - if messages is not None: - if not isinstance(messages, list): - messages = [messages] - await self._runtime.send_message( GroupChatStart(messages=messages), recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), From efea4339d6a6a7ed9030a43723f9ab6e1b5d2205 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Sun, 15 Dec 2024 00:26:56 +0530 Subject: [PATCH 17/19] Refactor group chat message handling to validate all messages in the start process instead of just the first one --- .../_group_chat/_magentic_one/_magentic_one_orchestrator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index fbd9b3b0816e..dcdf8b91809d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -128,10 +128,10 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No return assert message is not None and message.messages is not None - # Validate the group state given the start message. - await self.validate_group_state([message.messages[0]]) + # Validate the group state given all the messages. + await self.validate_group_state(message.messages) - # Log the start message. + # Log the message. await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) # Outer Loop for first time # Create the initial task ledger From cb0a734d0e82be4fde6f33667e78514a26382d2b Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sat, 14 Dec 2024 21:35:35 -0800 Subject: [PATCH 18/19] Update society of mind agent to use the list input task --- .../agents/_society_of_mind_agent.py | 110 ++++++++++-------- .../src/autogen_agentchat/state/__init__.py | 2 + .../src/autogen_agentchat/state/_states.py | 7 ++ .../tests/test_society_of_mind_agent.py | 21 +++- 4 files changed, 89 insertions(+), 51 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py index 9074760c5ed8..0ca49a990248 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py @@ -1,10 +1,10 @@ -from typing import AsyncGenerator, List, Sequence +from typing import Any, AsyncGenerator, List, Mapping, Sequence -from autogen_core import CancellationToken, Image -from autogen_core.models import ChatCompletionClient -from autogen_core.models._types import SystemMessage +from autogen_core import CancellationToken +from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage from autogen_agentchat.base import Response +from autogen_agentchat.state import SocietyOfMindAgentState from ..base import TaskResult, Team from ..messages import ( @@ -32,6 +32,10 @@ class SocietyOfMindAgent(BaseChatAgent): team (Team): The team of agents to use. model_client (ChatCompletionClient): The model client to use for preparing responses. description (str, optional): The description of the agent. + instruction (str, optional): The instruction to use when generating a response using the inner team's messages. + Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'. + response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages. + Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'. Example: @@ -39,35 +43,51 @@ class SocietyOfMindAgent(BaseChatAgent): .. code-block:: python import asyncio + from autogen_agentchat.ui import Console from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_agentchat.teams import RoundRobinGroupChat - from autogen_agentchat.conditions import MaxMessageTermination + from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination async def main() -> None: model_client = OpenAIChatCompletionClient(model="gpt-4o") - agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") - agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") - inner_termination = MaxMessageTermination(3) + agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.") + agent2 = AssistantAgent( + "assistant2", + model_client=model_client, + system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.", + ) + inner_termination = TextMentionTermination("APPROVE") inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) - agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.") - agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.") - outter_termination = MaxMessageTermination(10) - team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination) + agent3 = AssistantAgent( + "assistant3", model_client=model_client, system_message="Translate the text to Spanish." + ) + team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2) - stream = team.run_stream(task="Tell me a one-liner joke.") - async for message in stream: - print(message) + stream = team.run_stream(task="Write a short story with a surprising ending.") + await Console(stream) asyncio.run(main()) """ + DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:" + """str: The default instruction to use when generating a response using the + inner team's messages. The instruction will be prepended to the inner team's + messages when generating a response using the model. It assumes the role of + 'system'.""" + + DEFAULT_RESPONSE_PROMPT = ( + "Output a standalone response to the original request, without mentioning any of the intermediate discussion." + ) + """str: The default response prompt to use when generating a response using + the inner team's messages. It assumes the role of 'system'.""" + def __init__( self, name: str, @@ -75,17 +95,13 @@ def __init__( model_client: ChatCompletionClient, *, description: str = "An agent that uses an inner team of agents to generate responses.", - task_prompt: str = "{transcript}\nContinue.", - response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.", + instruction: str = DEFAULT_INSTRUCTION, + response_prompt: str = DEFAULT_RESPONSE_PROMPT, ) -> None: super().__init__(name=name, description=description) self._team = team self._model_client = model_client - if "{transcript}" not in task_prompt: - raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.") - self._task_prompt = task_prompt - if "{transcript}" not in response_prompt: - raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.") + self._instruction = instruction self._response_prompt = response_prompt @property @@ -104,33 +120,41 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - # Build the context. - delta = list(messages) - task: str | None = None - if len(delta) > 0: - task = self._task_prompt.format(transcript=self._create_transcript(delta)) + # Prepare the task for the team of agents. + task = list(messages) # Run the team of agents. result: TaskResult | None = None inner_messages: List[AgentMessage] = [] + count = 0 async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token): if isinstance(inner_msg, TaskResult): result = inner_msg else: + count += 1 + if count <= len(task): + # Skip the task messages. + continue yield inner_msg inner_messages.append(inner_msg) assert result is not None - if len(inner_messages) < 2: - # The first message is the task message so we need at least 2 messages. + if len(inner_messages) == 0: yield Response( chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages ) else: - prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:])) - completion = await self._model_client.create( - messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token + # Generate a response using the model client. + llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)] + llm_messages.extend( + [ + UserMessage(content=message.content, source=message.source) + for message in inner_messages + if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage) + ] ) + llm_messages.append(SystemMessage(content=self._response_prompt)) + completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token) assert isinstance(completion.content, str) yield Response( chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage), @@ -143,17 +167,11 @@ async def on_messages_stream( async def on_reset(self, cancellation_token: CancellationToken) -> None: await self._team.reset() - def _create_transcript(self, messages: Sequence[AgentMessage]) -> str: - transcript = "" - for message in messages: - if isinstance(message, TextMessage | StopMessage | HandoffMessage): - transcript += f"{message.source}: {message.content}\n" - elif isinstance(message, MultiModalMessage): - for content in message.content: - if isinstance(content, Image): - transcript += f"{message.source}: [Image]\n" - else: - transcript += f"{message.source}: {content}\n" - else: - raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}") - return transcript + async def save_state(self) -> Mapping[str, Any]: + team_state = await self._team.save_state() + state = SocietyOfMindAgentState(inner_team_state=team_state) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + society_of_mind_state = SocietyOfMindAgentState.model_validate(state) + await self._team.load_state(society_of_mind_state.inner_team_state) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py index abb468a70b62..3cb3efa8145d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py @@ -8,6 +8,7 @@ MagenticOneOrchestratorState, RoundRobinManagerState, SelectorManagerState, + SocietyOfMindAgentState, SwarmManagerState, TeamState, ) @@ -22,4 +23,5 @@ "SwarmManagerState", "MagenticOneOrchestratorState", "TeamState", + "SocietyOfMindAgentState", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py index d54c533c5440..4bf3d4709943 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py @@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState): n_rounds: int = Field(default=0) n_stalls: int = Field(default=0) type: str = Field(default="MagenticOneOrchestratorState") + + +class SocietyOfMindAgentState(BaseState): + """State for a Society of Mind agent.""" + + inner_team_state: Mapping[str, Any] = Field(default_factory=dict) + type: str = Field(default="SocietyOfMindAgentState") diff --git a/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py b/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py index ec4bf08b138c..9bf4713d9c43 100644 --- a/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py @@ -72,9 +72,20 @@ async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None: inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) response = await society_of_mind_agent.run(task="Count to 10.") - assert len(response.messages) == 5 + assert len(response.messages) == 4 assert response.messages[0].source == "user" - assert response.messages[1].source == "user" - assert response.messages[2].source == "assistant1" - assert response.messages[3].source == "assistant2" - assert response.messages[4].source == "society_of_mind" + assert response.messages[1].source == "assistant1" + assert response.messages[2].source == "assistant2" + assert response.messages[3].source == "society_of_mind" + + # Test save and load state. + state = await society_of_mind_agent.save_state() + assert state is not None + agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") + agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") + inner_termination = MaxMessageTermination(3) + inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) + society_of_mind_agent2 = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) + await society_of_mind_agent2.load_state(state) + state2 = await society_of_mind_agent2.save_state() + assert state == state2 From 6be1f73511da3ebce35fef87d7350ba3587504db Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sat, 14 Dec 2024 21:44:03 -0800 Subject: [PATCH 19/19] fix doc example --- .../src/autogen_agentchat/agents/_society_of_mind_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py index 0ca49a990248..4c13c10bc386 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py @@ -47,7 +47,7 @@ class SocietyOfMindAgent(BaseChatAgent): from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_agentchat.teams import RoundRobinGroupChat - from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination + from autogen_agentchat.conditions import TextMentionTermination async def main() -> None: