diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 5b2aed4860c1..c06fb8d6db53 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,10 +1,14 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, List, Mapping, Sequence +from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args from autogen_core import CancellationToken from ..base import ChatAgent, Response, TaskResult -from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage +from ..messages import ( + AgentMessage, + ChatMessage, + TextMessage, +) from ..state import BaseState @@ -45,8 +49,9 @@ async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: """Handles incoming messages and returns a stream of messages and - and the final item is the response. The base implementation in :class:`BaseChatAgent` - simply calls :meth:`on_messages` and yields the messages in the response.""" + and the final item is the response. The base implementation in + :class:`BaseChatAgent` simply calls :meth:`on_messages` and yields + the messages in the response.""" response = await self.on_messages(messages, cancellation_token) for inner_message in response.inner_messages or []: yield inner_message @@ -55,7 +60,7 @@ async def on_messages_stream( async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the agent with the given task and return the result.""" @@ -69,7 +74,14 @@ async def run( text_msg = TextMessage(content=task, source="user") input_messages.append(text_msg) output_messages.append(text_msg) - elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): + elif isinstance(task, list): + for msg in task: + if isinstance(msg, get_args(ChatMessage)[0]): + input_messages.append(msg) + output_messages.append(msg) + else: + raise ValueError(f"Invalid message type in list: {type(msg)}") + elif isinstance(task, get_args(ChatMessage)[0]): input_messages.append(task) output_messages.append(task) else: @@ -83,7 +95,7 @@ async def run( async def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the agent with the given task and return a stream of messages @@ -99,7 +111,15 @@ async def run_stream( input_messages.append(text_msg) output_messages.append(text_msg) yield text_msg - elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): + elif isinstance(task, list): + for msg in task: + if isinstance(msg, get_args(ChatMessage)[0]): + input_messages.append(msg) + output_messages.append(msg) + yield msg + else: + raise ValueError(f"Invalid message type in list: {type(msg)}") + elif isinstance(task, get_args(ChatMessage)[0]): input_messages.append(task) output_messages.append(task) yield task diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py index 9074760c5ed8..4c13c10bc386 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py @@ -1,10 +1,10 @@ -from typing import AsyncGenerator, List, Sequence +from typing import Any, AsyncGenerator, List, Mapping, Sequence -from autogen_core import CancellationToken, Image -from autogen_core.models import ChatCompletionClient -from autogen_core.models._types import SystemMessage +from autogen_core import CancellationToken +from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage from autogen_agentchat.base import Response +from autogen_agentchat.state import SocietyOfMindAgentState from ..base import TaskResult, Team from ..messages import ( @@ -32,6 +32,10 @@ class SocietyOfMindAgent(BaseChatAgent): team (Team): The team of agents to use. model_client (ChatCompletionClient): The model client to use for preparing responses. description (str, optional): The description of the agent. + instruction (str, optional): The instruction to use when generating a response using the inner team's messages. + Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'. + response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages. + Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'. Example: @@ -39,35 +43,51 @@ class SocietyOfMindAgent(BaseChatAgent): .. code-block:: python import asyncio + from autogen_agentchat.ui import Console from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_agentchat.teams import RoundRobinGroupChat - from autogen_agentchat.conditions import MaxMessageTermination + from autogen_agentchat.conditions import TextMentionTermination async def main() -> None: model_client = OpenAIChatCompletionClient(model="gpt-4o") - agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") - agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") - inner_termination = MaxMessageTermination(3) + agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.") + agent2 = AssistantAgent( + "assistant2", + model_client=model_client, + system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.", + ) + inner_termination = TextMentionTermination("APPROVE") inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) - agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.") - agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.") - outter_termination = MaxMessageTermination(10) - team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination) + agent3 = AssistantAgent( + "assistant3", model_client=model_client, system_message="Translate the text to Spanish." + ) + team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2) - stream = team.run_stream(task="Tell me a one-liner joke.") - async for message in stream: - print(message) + stream = team.run_stream(task="Write a short story with a surprising ending.") + await Console(stream) asyncio.run(main()) """ + DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:" + """str: The default instruction to use when generating a response using the + inner team's messages. The instruction will be prepended to the inner team's + messages when generating a response using the model. It assumes the role of + 'system'.""" + + DEFAULT_RESPONSE_PROMPT = ( + "Output a standalone response to the original request, without mentioning any of the intermediate discussion." + ) + """str: The default response prompt to use when generating a response using + the inner team's messages. It assumes the role of 'system'.""" + def __init__( self, name: str, @@ -75,17 +95,13 @@ def __init__( model_client: ChatCompletionClient, *, description: str = "An agent that uses an inner team of agents to generate responses.", - task_prompt: str = "{transcript}\nContinue.", - response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.", + instruction: str = DEFAULT_INSTRUCTION, + response_prompt: str = DEFAULT_RESPONSE_PROMPT, ) -> None: super().__init__(name=name, description=description) self._team = team self._model_client = model_client - if "{transcript}" not in task_prompt: - raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.") - self._task_prompt = task_prompt - if "{transcript}" not in response_prompt: - raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.") + self._instruction = instruction self._response_prompt = response_prompt @property @@ -104,33 +120,41 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: - # Build the context. - delta = list(messages) - task: str | None = None - if len(delta) > 0: - task = self._task_prompt.format(transcript=self._create_transcript(delta)) + # Prepare the task for the team of agents. + task = list(messages) # Run the team of agents. result: TaskResult | None = None inner_messages: List[AgentMessage] = [] + count = 0 async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token): if isinstance(inner_msg, TaskResult): result = inner_msg else: + count += 1 + if count <= len(task): + # Skip the task messages. + continue yield inner_msg inner_messages.append(inner_msg) assert result is not None - if len(inner_messages) < 2: - # The first message is the task message so we need at least 2 messages. + if len(inner_messages) == 0: yield Response( chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages ) else: - prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:])) - completion = await self._model_client.create( - messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token + # Generate a response using the model client. + llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)] + llm_messages.extend( + [ + UserMessage(content=message.content, source=message.source) + for message in inner_messages + if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage) + ] ) + llm_messages.append(SystemMessage(content=self._response_prompt)) + completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token) assert isinstance(completion.content, str) yield Response( chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage), @@ -143,17 +167,11 @@ async def on_messages_stream( async def on_reset(self, cancellation_token: CancellationToken) -> None: await self._team.reset() - def _create_transcript(self, messages: Sequence[AgentMessage]) -> str: - transcript = "" - for message in messages: - if isinstance(message, TextMessage | StopMessage | HandoffMessage): - transcript += f"{message.source}: {message.content}\n" - elif isinstance(message, MultiModalMessage): - for content in message.content: - if isinstance(content, Image): - transcript += f"{message.source}: [Image]\n" - else: - transcript += f"{message.source}: {content}\n" - else: - raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}") - return transcript + async def save_state(self) -> Mapping[str, Any]: + team_state = await self._team.save_state() + state = SocietyOfMindAgentState(inner_team_state=team_state) + return state.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + society_of_mind_state = SocietyOfMindAgentState.model_validate(state) + await self._team.load_state(society_of_mind_state.inner_team_state) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index f617b3823451..d0d9aefe70a7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import AsyncGenerator, Protocol, Sequence +from typing import AsyncGenerator, List, Protocol, Sequence from autogen_core import CancellationToken @@ -23,7 +23,7 @@ class TaskRunner(Protocol): async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the task and return the result. @@ -36,7 +36,7 @@ async def run( def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the task and produces a stream of messages and the final result diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py index abb468a70b62..3cb3efa8145d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py @@ -8,6 +8,7 @@ MagenticOneOrchestratorState, RoundRobinManagerState, SelectorManagerState, + SocietyOfMindAgentState, SwarmManagerState, TeamState, ) @@ -22,4 +23,5 @@ "SwarmManagerState", "MagenticOneOrchestratorState", "TeamState", + "SocietyOfMindAgentState", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py index d54c533c5440..4bf3d4709943 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py @@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState): n_rounds: int = Field(default=0) n_stalls: int = Field(default=0) type: str = Field(default="MagenticOneOrchestratorState") + + +class SocietyOfMindAgentState(BaseState): + """State for a Society of Mind agent.""" + + inner_team_state: Mapping[str, Any] = Field(default_factory=dict) + type: str = Field(default="SocietyOfMindAgentState") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index fe405d52ed2a..7a6496acbe87 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -2,7 +2,7 @@ import logging import uuid from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, List, Mapping +from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args from autogen_core import ( AgentId, @@ -19,7 +19,7 @@ from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TaskResult, Team, TerminationCondition -from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage +from ...messages import AgentMessage, ChatMessage, TextMessage from ...state import TeamState from ._chat_agent_container import ChatAgentContainer from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination @@ -146,11 +146,18 @@ async def collect_output_messages( message: GroupChatStart | GroupChatMessage | GroupChatTermination, ctx: MessageContext, ) -> None: - event_logger.info(message.message) - if isinstance(message, GroupChatTermination): + """Collect output messages from the group chat.""" + if isinstance(message, GroupChatStart): + if message.messages is not None: + for msg in message.messages: + event_logger.info(msg) + await self._output_message_queue.put(msg) + elif isinstance(message, GroupChatMessage): + event_logger.info(message.message) + await self._output_message_queue.put(message.message) + elif isinstance(message, GroupChatTermination): + event_logger.info(message.message) self._stop_reason = message.message.content - return - await self._output_message_queue.put(message.message) await ClosureAgent.register_closure( runtime, @@ -165,7 +172,7 @@ async def collect_output_messages( async def run( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the team and return the result. The base implementation uses @@ -173,7 +180,7 @@ async def run( Once the team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | None): The task to run the team with. + task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. @@ -264,7 +271,7 @@ async def main() -> None: async def run_stream( self, *, - task: str | ChatMessage | None = None, + task: str | ChatMessage | List[ChatMessage] | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result @@ -272,7 +279,7 @@ async def run_stream( team is stopped, the termination condition is reset. Args: - task (str | ChatMessage | None): The task to run the team with. + task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. Setting the cancellation token potentially put the team in an inconsistent state, and it may not reset the termination condition. @@ -355,16 +362,20 @@ async def main() -> None: """ - # Create the first chat message if the task is a string or a chat message. - first_chat_message: ChatMessage | None = None + # Create the messages list if the task is a string or a chat message. + messages: List[ChatMessage] | None = None if task is None: pass elif isinstance(task, str): - first_chat_message = TextMessage(content=task, source="user") - elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): - first_chat_message = task - else: - raise ValueError(f"Invalid task type: {type(task)}") + messages = [TextMessage(content=task, source="user")] + elif isinstance(task, get_args(ChatMessage)[0]): + messages = [task] # type: ignore + elif isinstance(task, list): + if not task: + raise ValueError("Task list cannot be empty") + if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task): + raise ValueError("All messages in task list must be valid ChatMessage types") + messages = task if self._is_running: raise ValueError("The team is already running, it cannot run again until it is stopped.") @@ -389,7 +400,7 @@ async def stop_runtime() -> None: # The group chat manager will start the group chat by relaying the message to the participants # and the closure agent. await self._runtime.send_message( - GroupChatStart(message=first_chat_message), + GroupChatStart(messages=messages), recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), cancellation_token=cancellation_token, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index aefe4f8d49d8..84725cecdd65 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -70,24 +70,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No # Stop the group chat. return - # Validate the group state given the start message. - await self.validate_group_state(message.message) + # Validate the group state given the start messages + await self.validate_group_state(message.messages) - if message.message is not None: - # Log the start message. - await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) + if message.messages is not None: + # Log all messages at once + await self.publish_message( + GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type) + ) - # Relay the start message to the participants. + # Relay all messages at once to participants await self.publish_message( - message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token + GroupChatStart(messages=message.messages), + topic_id=DefaultTopicId(type=self._group_topic_type), + cancellation_token=ctx.cancellation_token, ) - # Append the user message to the message thread. - self._message_thread.append(message.message) + # Append all messages to thread + self._message_thread.extend(message.messages) - # Check if the conversation should be terminated. + # Check termination condition after processing all messages if self._termination_condition is not None: - stop_message = await self._termination_condition([message.message]) + stop_message = await self._termination_condition(message.messages) if stop_message is not None: await self.publish_message( GroupChatTermination(message=stop_message), @@ -97,7 +101,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No await self._termination_condition.reset() return - # Select a speaker to start the conversation. + # Select a speaker to start/continue the conversation speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) # Link the select speaker future to the cancellation token. ctx.cancellation_token.link_future(speaker_topic_type_future) @@ -166,8 +170,13 @@ async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> No await self.reset() @abstractmethod - async def validate_group_state(self, message: ChatMessage | None) -> None: - """Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event.""" + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: + """Validate the state of the group chat given the start messages. + This is executed when the group chat manager receives a GroupChatStart event. + + Args: + messages: A list of chat messages to validate, or None if no messages are provided. + """ ... @abstractmethod diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index fdf5428b3b5c..f01465d4c3d5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -30,8 +30,8 @@ def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAg @event async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: """Handle a start event by appending the content to the buffer.""" - if message.message is not None: - self._message_buffer.append(message.message) + if message.messages is not None: + self._message_buffer.extend(message.messages) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index 4ae4d892cace..ed325fcb5159 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -1,3 +1,5 @@ +from typing import List + from pydantic import BaseModel from ...base import Response @@ -7,8 +9,8 @@ class GroupChatStart(BaseModel): """A request to start a group chat.""" - message: ChatMessage | None = None - """An optional user message to start the group chat.""" + messages: List[ChatMessage] | None = None + """An optional list of messages to start the group chat.""" class GroupChatAgentResponse(BaseModel): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 5a910e285dbb..dcdf8b91809d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -126,17 +126,18 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No ) # Stop the group chat. return - assert message is not None and message.message is not None + assert message is not None and message.messages is not None - # Validate the group state given the start message. - await self.validate_group_state(message.message) + # Validate the group state given all the messages. + await self.validate_group_state(message.messages) - # Log the start message. + # Log the message. await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) # Outer Loop for first time # Create the initial task ledger ################################# - self._task = self._content_to_str(message.message.content) + # Combine all message contents for task + self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages]) planning_conversation: List[LLMMessage] = [] # 1. GATHER FACTS @@ -184,7 +185,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess return await self._orchestrate_step(ctx.cancellation_token) - async def validate_group_state(self, message: ChatMessage | None) -> None: + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass async def save_state(self) -> Mapping[str, Any]: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 8330c89a3f12..3e17943b90b2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -29,7 +29,7 @@ def __init__( ) self._next_speaker_index = 0 - async def validate_group_state(self, message: ChatMessage | None) -> None: + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass async def reset(self) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index a42b6fc2bfbb..5f161d0c6858 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -54,7 +54,7 @@ def __init__( self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func - async def validate_group_state(self, message: ChatMessage | None) -> None: + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: pass async def reset(self) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index a2659636fc09..436fe8e4cdae 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -29,16 +29,19 @@ def __init__( ) self._current_speaker = participant_topic_types[0] - async def validate_group_state(self, message: ChatMessage | None) -> None: - """Validate the start message for the group chat.""" - # Check if the start message is a handoff message. - if isinstance(message, HandoffMessage): - if message.target not in self._participant_topic_types: - raise ValueError( - f"The target {message.target} is not one of the participants {self._participant_topic_types}. " - "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target." - ) - return + async def validate_group_state(self, messages: List[ChatMessage] | None) -> None: + """Validate the start messages for the group chat.""" + # Check if any of the start messages is a handoff message. + if messages: + for message in messages: + if isinstance(message, HandoffMessage): + if message.target not in self._participant_topic_types: + raise ValueError( + f"The target {message.target} is not one of the participants {self._participant_topic_types}. " + "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target." + ) + return + # Check if there is a handoff message in the thread that is not targeting a valid participant. for existing_message in reversed(self._message_thread): if isinstance(existing_message, HandoffMessage): diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 8f0b2d00cb51..c132e3a4862c 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -8,6 +8,7 @@ from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult from autogen_agentchat.messages import ( + ChatMessage, HandoffMessage, MultiModalMessage, TextMessage, @@ -21,7 +22,10 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_message import ChatCompletionMessage -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) from openai.types.completion_usage import CompletionUsage from utils import FileLogHandler @@ -33,14 +37,14 @@ class _MockChatCompletion: def __init__(self, chat_completions: List[ChatCompletion]) -> None: self._saved_chat_completions = chat_completions - self._curr_index = 0 + self.curr_index = 0 async def mock_create( self, *args: Any, **kwargs: Any ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]: await asyncio.sleep(0.1) - completion = self._saved_chat_completions[self._curr_index] - self._curr_index += 1 + completion = self._saved_chat_completions[self.curr_index] + self.curr_index += 1 return completion @@ -90,7 +94,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: ChatCompletion( id="id2", choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="pass", role="assistant"), + ) ], created=0, model=model, @@ -101,7 +109,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: id="id2", choices=[ Choice( - finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant") + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="TERMINATE", role="assistant"), ) ], created=0, @@ -115,7 +125,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: agent = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], ) result = await agent.run(task="task") @@ -133,14 +147,14 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[3].models_usage is None # Test streaming. - mock._curr_index = 0 # pyright: ignore + mock.curr_index = 0 # Reset the mock index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): assert message == result else: assert message == result.messages[index] - index += 1 + index += 1 # Test state saving and loading. state = await agent.save_state() @@ -234,7 +248,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> assert result.messages[3].models_usage.prompt_tokens == 10 # Test streaming. - mock._curr_index = 0 # pyright: ignore + mock.curr_index = 0 # pyright: ignore index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): @@ -248,7 +262,11 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> agent2 = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], ) await agent2.load_state(state) state2 = await agent2.save_state() @@ -293,7 +311,11 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: tool_use_agent = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], handoffs=[handoff], ) assert HandoffMessage in tool_use_agent.produced_message_types @@ -313,7 +335,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: assert result.messages[3].models_usage is None # Test streaming. - mock._curr_index = 0 # pyright: ignore + mock.curr_index = 0 # pyright: ignore index = 0 async for message in tool_use_agent.run_stream(task="task"): if isinstance(message, TaskResult): @@ -330,7 +352,11 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: ChatCompletion( id="id2", choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Hello", role="assistant"), + ) ], created=0, model=model, @@ -340,7 +366,10 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: ] mock = _MockChatCompletion(chat_completions) monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) - agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")) + agent = AssistantAgent( + name="assistant", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + ) # Generate a random base64 image. img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) @@ -351,14 +380,24 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: async def test_invalid_model_capabilities() -> None: model = "random-model" model_client = OpenAIChatCompletionClient( - model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False} + model=model, + api_key="", + model_capabilities={ + "vision": False, + "function_calling": False, + "json_output": False, + }, ) with pytest.raises(ValueError): agent = AssistantAgent( name="assistant", model_client=model_client, - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[ + _pass_function, + _fail_function, + FunctionTool(_echo_function, description="Echo"), + ], ) with pytest.raises(ValueError): @@ -369,3 +408,62 @@ async def test_invalid_model_capabilities() -> None: # Generate a random base64 image. img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)])) + + +@pytest.mark.asyncio +async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Response to message 1", role="assistant"), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + agent = AssistantAgent( + "test_agent", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + ) + + # Create a list of chat messages + messages: List[ChatMessage] = [ + TextMessage(content="Message 1", source="user"), + TextMessage(content="Message 2", source="user"), + ] + + # Test run method with list of messages + result = await agent.run(task=messages) + assert len(result.messages) == 3 # 2 input messages + 1 response message + assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].content == "Message 1" + assert result.messages[0].source == "user" + assert isinstance(result.messages[1], TextMessage) + assert result.messages[1].content == "Message 2" + assert result.messages[1].source == "user" + assert isinstance(result.messages[2], TextMessage) + assert result.messages[2].content == "Response to message 1" + assert result.messages[2].source == "test_agent" + assert result.messages[2].models_usage is not None + assert result.messages[2].models_usage.completion_tokens == 5 + assert result.messages[2].models_usage.prompt_tokens == 10 + + # Test run_stream method with list of messages + mock.curr_index = 0 # Reset mock index using public attribute + index = 0 + async for message in agent.run_stream(task=messages): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index c64fcbe3232b..6a3c91b805c3 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -1025,3 +1025,48 @@ async def test_swarm_with_handoff_termination() -> None: assert result.messages[1].content == "Transferred to second_agent." assert result.messages[2].content == "Transferred to third_agent." assert result.messages[3].content == "Transferred to non_existing_agent." + + +@pytest.mark.asyncio +async def test_round_robin_group_chat_with_message_list() -> None: + # Create a simple team with echo agents + agent1 = _EchoAgent("Agent1", "First agent") + agent2 = _EchoAgent("Agent2", "Second agent") + termination = MaxMessageTermination(4) # Stop after 4 messages + team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) + + # Create a list of messages + messages: List[ChatMessage] = [ + TextMessage(content="Message 1", source="user"), + TextMessage(content="Message 2", source="user"), + TextMessage(content="Message 3", source="user"), + ] + + # Run the team with the message list + result = await team.run(task=messages) + + # Verify the messages were processed in order + assert len(result.messages) == 4 # Initial messages + echo until termination + assert result.messages[0].content == "Message 1" # First message + assert result.messages[1].content == "Message 2" # Second message + assert result.messages[2].content == "Message 3" # Third message + assert result.messages[3].content == "Message 1" # Echo from first agent + assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4" + + # Test with streaming + await team.reset() + index = 0 + async for message in team.run_stream(task=messages): + if isinstance(message, TaskResult): + assert message == result + else: + assert message == result.messages[index] + index += 1 + + # Test with invalid message list + with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"): + await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input + + # Test with empty message list + with pytest.raises(ValueError, match="Task list cannot be empty"): + await team.run(task=[]) diff --git a/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py b/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py index ec4bf08b138c..9bf4713d9c43 100644 --- a/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py @@ -72,9 +72,20 @@ async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None: inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) response = await society_of_mind_agent.run(task="Count to 10.") - assert len(response.messages) == 5 + assert len(response.messages) == 4 assert response.messages[0].source == "user" - assert response.messages[1].source == "user" - assert response.messages[2].source == "assistant1" - assert response.messages[3].source == "assistant2" - assert response.messages[4].source == "society_of_mind" + assert response.messages[1].source == "assistant1" + assert response.messages[2].source == "assistant2" + assert response.messages[3].source == "society_of_mind" + + # Test save and load state. + state = await society_of_mind_agent.save_state() + assert state is not None + agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") + agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") + inner_termination = MaxMessageTermination(3) + inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination) + society_of_mind_agent2 = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client) + await society_of_mind_agent2.load_state(state) + state2 = await society_of_mind_agent2.save_state() + assert state == state2