Skip to content

Commit

Permalink
feat: add support for list of messages as team task input and update …
Browse files Browse the repository at this point in the history
…Society of Mind Agent (#4500)

* feat: add support for list of messages as team task input
* Update society of mind agent to use the list input task
---------

Co-authored-by: Copilot <[email protected]>
Co-authored-by: Ryan Sweet <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
4 people authored Dec 15, 2024
1 parent c714515 commit 7c0bbf6
Show file tree
Hide file tree
Showing 16 changed files with 360 additions and 133 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args

from autogen_core import CancellationToken

from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import (
AgentMessage,
ChatMessage,
TextMessage,
)
from ..state import BaseState


Expand Down Expand Up @@ -45,8 +49,9 @@ async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in :class:`BaseChatAgent`
simply calls :meth:`on_messages` and yields the messages in the response."""
the final item is the response. The base implementation in
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
the messages in the response."""
response = await self.on_messages(messages, cancellation_token)
for inner_message in response.inner_messages or []:
yield inner_message
Expand All @@ -55,7 +60,7 @@ async def on_messages_stream(
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
Expand All @@ -69,7 +74,14 @@ async def run(
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
else:
Expand All @@ -83,7 +95,7 @@ async def run(
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
Expand All @@ -99,7 +111,15 @@ async def run_stream(
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
yield task
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken, Image
from autogen_core.models import ChatCompletionClient
from autogen_core.models._types import SystemMessage
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage

from autogen_agentchat.base import Response
from autogen_agentchat.state import SocietyOfMindAgentState

from ..base import TaskResult, Team
from ..messages import (
Expand Down Expand Up @@ -32,60 +32,76 @@ class SocietyOfMindAgent(BaseChatAgent):
team (Team): The team of agents to use.
model_client (ChatCompletionClient): The model client to use for preparing responses.
description (str, optional): The description of the agent.
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.
Example:
.. code-block:: python
import asyncio
from autogen_agentchat.ui import Console
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.conditions import TextMentionTermination
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
agent2 = AssistantAgent(
"assistant2",
model_client=model_client,
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
)
inner_termination = TextMentionTermination("APPROVE")
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.")
agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.")
outter_termination = MaxMessageTermination(10)
team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination)
agent3 = AssistantAgent(
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
)
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)
stream = team.run_stream(task="Tell me a one-liner joke.")
async for message in stream:
print(message)
stream = team.run_stream(task="Write a short story with a surprising ending.")
await Console(stream)
asyncio.run(main())
"""

DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
"""str: The default instruction to use when generating a response using the
inner team's messages. The instruction will be prepended to the inner team's
messages when generating a response using the model. It assumes the role of
'system'."""

DEFAULT_RESPONSE_PROMPT = (
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
)
"""str: The default response prompt to use when generating a response using
the inner team's messages. It assumes the role of 'system'."""

def __init__(
self,
name: str,
team: Team,
model_client: ChatCompletionClient,
*,
description: str = "An agent that uses an inner team of agents to generate responses.",
task_prompt: str = "{transcript}\nContinue.",
response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.",
instruction: str = DEFAULT_INSTRUCTION,
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
) -> None:
super().__init__(name=name, description=description)
self._team = team
self._model_client = model_client
if "{transcript}" not in task_prompt:
raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.")
self._task_prompt = task_prompt
if "{transcript}" not in response_prompt:
raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.")
self._instruction = instruction
self._response_prompt = response_prompt

@property
Expand All @@ -104,33 +120,41 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Build the context.
delta = list(messages)
task: str | None = None
if len(delta) > 0:
task = self._task_prompt.format(transcript=self._create_transcript(delta))
# Prepare the task for the team of agents.
task = list(messages)

# Run the team of agents.
result: TaskResult | None = None
inner_messages: List[AgentMessage] = []
count = 0
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
if isinstance(inner_msg, TaskResult):
result = inner_msg
else:
count += 1
if count <= len(task):
# Skip the task messages.
continue
yield inner_msg
inner_messages.append(inner_msg)
assert result is not None

if len(inner_messages) < 2:
# The first message is the task message so we need at least 2 messages.
if len(inner_messages) == 0:
yield Response(
chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages
)
else:
prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:]))
completion = await self._model_client.create(
messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token
# Generate a response using the model client.
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
llm_messages.extend(
[
UserMessage(content=message.content, source=message.source)
for message in inner_messages
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
]
)
llm_messages.append(SystemMessage(content=self._response_prompt))
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
assert isinstance(completion.content, str)
yield Response(
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
Expand All @@ -143,17 +167,11 @@ async def on_messages_stream(
async def on_reset(self, cancellation_token: CancellationToken) -> None:
await self._team.reset()

def _create_transcript(self, messages: Sequence[AgentMessage]) -> str:
transcript = ""
for message in messages:
if isinstance(message, TextMessage | StopMessage | HandoffMessage):
transcript += f"{message.source}: {message.content}\n"
elif isinstance(message, MultiModalMessage):
for content in message.content:
if isinstance(content, Image):
transcript += f"{message.source}: [Image]\n"
else:
transcript += f"{message.source}: {content}\n"
else:
raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}")
return transcript
async def save_state(self) -> Mapping[str, Any]:
team_state = await self._team.save_state()
state = SocietyOfMindAgentState(inner_team_state=team_state)
return state.model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
await self._team.load_state(society_of_mind_state.inner_team_state)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Protocol, Sequence
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken

Expand All @@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
Expand All @@ -36,7 +36,7 @@ async def run(
def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produce a stream of messages and the final result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MagenticOneOrchestratorState,
RoundRobinManagerState,
SelectorManagerState,
SocietyOfMindAgentState,
SwarmManagerState,
TeamState,
)
Expand All @@ -22,4 +23,5 @@
"SwarmManagerState",
"MagenticOneOrchestratorState",
"TeamState",
"SocietyOfMindAgentState",
]
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState):
n_rounds: int = Field(default=0)
n_stalls: int = Field(default=0)
type: str = Field(default="MagenticOneOrchestratorState")


class SocietyOfMindAgentState(BaseState):
"""State for a Society of Mind agent."""

inner_team_state: Mapping[str, Any] = Field(default_factory=dict)
type: str = Field(default="SocietyOfMindAgentState")
Loading

0 comments on commit 7c0bbf6

Please sign in to comment.