From 424da3d1ab09e5bb8f7936148e3d2fe4e026b3b5 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Wed, 4 Dec 2024 16:03:19 +0530 Subject: [PATCH] feat: enhance task handling to support single and multiple messages in group chat --- .../src/autogen_agentchat/base/_task.py | 6 ++-- .../teams/_group_chat/_base_group_chat.py | 14 +++++---- .../_group_chat/_base_group_chat_manager.py | 30 ++++++++++++------- .../teams/_group_chat/_events.py | 6 ++-- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index f617b3823451..d0d9aefe70a7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import AsyncGenerator, Protocol, Sequence +from typing import AsyncGenerator, List, Protocol, Sequence from autogen_core import CancellationToken @@ -23,7 +23,7 @@ class TaskRunner(Protocol): async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the task and return the result. @@ -36,7 +36,7 @@ async def run( def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the task and produces a stream of messages and the final result diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 58f87e5a5f8c..8a9080caa163 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -149,7 +149,11 @@ async def collect_output_messages( if isinstance(message, GroupChatTermination): self._stop_reason = message.message.content return - await self._output_message_queue.put(message.message) + + # Handle single message or list of messages + messages = message.message if isinstance(message.message, List) else [message.message] + for msg in messages: + await self._output_message_queue.put(msg) await ClosureAgent.register_closure( runtime, @@ -164,7 +168,7 @@ async def collect_output_messages( async def run( self, *, - task: str | ChatMessage | list[ChatMessage] | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the team and return the result. The base implementation uses @@ -172,7 +176,7 @@ async def run( Once the team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages. + task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. @@ -263,7 +267,7 @@ async def main() -> None: async def run_stream( self, *, - task: str | ChatMessage | list[ChatMessage] | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result @@ -271,7 +275,7 @@ async def run_stream( team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages. + task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index aefe4f8d49d8..4b2294da44a9 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -74,20 +74,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self.validate_group_state(message.message) if message.message is not None: - # Log the start message. - await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) + messages_to_process = [message.message] if isinstance(message.message, ChatMessage) else message.message - # Relay the start message to the participants. - await self.publish_message( - message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token - ) + # Log and relay each message + for msg in messages_to_process: + # Log the message + await self.publish_message( + GroupChatStart(message=msg), topic_id=DefaultTopicId(type=self._output_topic_type) + ) + + # Relay the message to participants + await self.publish_message( + GroupChatStart(message=msg), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=ctx.cancellation_token, + ) - # Append the user message to the message thread. - self._message_thread.append(message.message) + # Append to message thread + self._message_thread.append(msg) - # Check if the conversation should be terminated. + # Check termination condition after processing all messages if self._termination_condition is not None: - stop_message = await self._termination_condition([message.message]) + stop_message = await self._termination_condition(messages_to_process) if stop_message is not None: await self.publish_message( GroupChatTermination(message=stop_message), @@ -97,7 +105,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self._termination_condition.reset() return - # Select a speaker to start the conversation. + # Select a speaker to start/continue the conversation speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) # Link the select speaker future to the cancellation token. ctx.cancellation_token.link_future(speaker_topic_type_future) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index 4ae4d892cace..fa0f556547b2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -1,3 +1,5 @@ +from typing import List + from pydantic import BaseModel from ...base import Response @@ -7,8 +9,8 @@ class GroupChatStart(BaseModel): """A request to start a group chat.""" - message: ChatMessage | None = None - """An optional user message to start the group chat.""" + message: ChatMessage | List[ChatMessage] | None = None + """An optional user message or list of messages to start the group chat.""" class GroupChatAgentResponse(BaseModel):