Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for list of messages as team task input #4500

Merged
merged 30 commits into main from list_messages_support on
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8139f7b
feat: add support for list of messages as team task input
iamarunbrahma Dec 3, 2024
424da3d
feat: enhance task handling to support single and multiple messages i…
iamarunbrahma Dec 4, 2024
9a7e0a6
Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/…
iamarunbrahma Dec 5, 2024
06e3d45
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 5, 2024
c0b34f3
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 6, 2024
71f3db5
Refactor message processing in handle_start to check for ChatMessage …
iamarunbrahma Dec 7, 2024
94c9e13
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 7, 2024
abaf4b9
update _base_group_chat_manager.py, _chat_agent_container.py and _mag…
iamarunbrahma Dec 7, 2024
c08b8cf
Refactor GroupChatStart to support multiple messages; update related …
iamarunbrahma Dec 7, 2024
cfe3aaf
Enhance task handling in _base_chat_agent.py to support a list of Cha…
iamarunbrahma Dec 8, 2024
ac011b8
Refactor validate_group_state method to accept a list of ChatMessages…
iamarunbrahma Dec 8, 2024
9712f2b
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 9, 2024
376f337
feat: add list input support for chat messages with unit tests
iamarunbrahma Dec 9, 2024
9666740
fix: type check for Annotated types
iamarunbrahma Dec 10, 2024
068008e
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 10, 2024
46381ec
fix(test): Update mock chat completion message content to match expec…
iamarunbrahma Dec 10, 2024
cd7a245
fix: ChatMessage to List[ChatMessage]
iamarunbrahma Dec 10, 2024
f2f37c3
Merge branch 'main' into list_messages_support
rysweet Dec 10, 2024
0f875d1
fixed mypy and pyright issues related to type checks
iamarunbrahma Dec 10, 2024
27d542e
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 11, 2024
cbe21af
fix: formatting of _base_group_chat
iamarunbrahma Dec 11, 2024
65d10cb
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 12, 2024
55feb5f
refactor: improve message handling in base group chat
iamarunbrahma Dec 12, 2024
b5c0d1a
refactor: streamline task validation and message handling in base gro…
iamarunbrahma Dec 12, 2024
e4aaccf
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 13, 2024
efea433
Refactor group chat message handling to validate all messages in the …
iamarunbrahma Dec 14, 2024
adbf850
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 14, 2024
cb0a734
Update society of mind agent to use the list input task
ekzhu Dec 15, 2024
0699bca
Merge branch 'main' into list_messages_support
ekzhu Dec 15, 2024
6be1f73
fix doc example
ekzhu Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args

from autogen_core import CancellationToken

from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import (
AgentMessage,
ChatMessage,
TextMessage,
)
from ..state import BaseState


Expand Down Expand Up @@ -45,8 +49,9 @@ async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in :class:`BaseChatAgent`
simply calls :meth:`on_messages` and yields the messages in the response."""
and the final item is the response. The base implementation in
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
the messages in the response."""
response = await self.on_messages(messages, cancellation_token)
for inner_message in response.inner_messages or []:
yield inner_message
Expand All @@ -55,7 +60,7 @@ async def on_messages_stream(
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
Expand All @@ -69,7 +74,14 @@ async def run(
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
else:
Expand All @@ -83,7 +95,7 @@ async def run(
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
Expand All @@ -99,7 +111,15 @@ async def run_stream(
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
yield task
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken, Image
from autogen_core.models import ChatCompletionClient
from autogen_core.models._types import SystemMessage
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage

from autogen_agentchat.base import Response
from autogen_agentchat.state import SocietyOfMindAgentState

from ..base import TaskResult, Team
from ..messages import (
Expand Down Expand Up @@ -32,60 +32,76 @@ class SocietyOfMindAgent(BaseChatAgent):
team (Team): The team of agents to use.
model_client (ChatCompletionClient): The model client to use for preparing responses.
description (str, optional): The description of the agent.
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.


Example:

.. code-block:: python

import asyncio
from autogen_agentchat.ui import Console
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.conditions import TextMentionTermination


async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")

agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
agent2 = AssistantAgent(
"assistant2",
model_client=model_client,
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
)
inner_termination = TextMentionTermination("APPROVE")
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)

society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)

agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.")
agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.")
outter_termination = MaxMessageTermination(10)
team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination)
agent3 = AssistantAgent(
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
)
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)

stream = team.run_stream(task="Tell me a one-liner joke.")
async for message in stream:
print(message)
stream = team.run_stream(task="Write a short story with a surprising ending.")
await Console(stream)


asyncio.run(main())
"""

DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
"""str: The default instruction to use when generating a response using the
inner team's messages. The instruction will be prepended to the inner team's
messages when generating a response using the model. It assumes the role of
'system'."""

DEFAULT_RESPONSE_PROMPT = (
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
)
"""str: The default response prompt to use when generating a response using
the inner team's messages. It assumes the role of 'system'."""

def __init__(
self,
name: str,
team: Team,
model_client: ChatCompletionClient,
*,
description: str = "An agent that uses an inner team of agents to generate responses.",
task_prompt: str = "{transcript}\nContinue.",
response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.",
instruction: str = DEFAULT_INSTRUCTION,
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
) -> None:
super().__init__(name=name, description=description)
self._team = team
self._model_client = model_client
if "{transcript}" not in task_prompt:
raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.")
self._task_prompt = task_prompt
if "{transcript}" not in response_prompt:
raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.")
self._instruction = instruction
self._response_prompt = response_prompt

@property
Expand All @@ -104,33 +120,41 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Build the context.
delta = list(messages)
task: str | None = None
if len(delta) > 0:
task = self._task_prompt.format(transcript=self._create_transcript(delta))
# Prepare the task for the team of agents.
task = list(messages)

# Run the team of agents.
result: TaskResult | None = None
inner_messages: List[AgentMessage] = []
count = 0
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
if isinstance(inner_msg, TaskResult):
result = inner_msg
else:
count += 1
if count <= len(task):
# Skip the task messages.
continue
yield inner_msg
inner_messages.append(inner_msg)
assert result is not None

if len(inner_messages) < 2:
# The first message is the task message so we need at least 2 messages.
if len(inner_messages) == 0:
yield Response(
chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages
)
else:
prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:]))
completion = await self._model_client.create(
messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token
# Generate a response using the model client.
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
llm_messages.extend(
[
UserMessage(content=message.content, source=message.source)
for message in inner_messages
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
]
)
llm_messages.append(SystemMessage(content=self._response_prompt))
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
assert isinstance(completion.content, str)
yield Response(
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
Expand All @@ -143,17 +167,11 @@ async def on_messages_stream(
async def on_reset(self, cancellation_token: CancellationToken) -> None:
    """Reset the agent to its initial state by resetting the inner team.

    NOTE(review): cancellation_token is accepted but not forwarded to the
    inner team's reset() — confirm whether reset should be cancellable.
    """
    await self._team.reset()

def _create_transcript(self, messages: Sequence[AgentMessage]) -> str:
transcript = ""
for message in messages:
if isinstance(message, TextMessage | StopMessage | HandoffMessage):
transcript += f"{message.source}: {message.content}\n"
elif isinstance(message, MultiModalMessage):
for content in message.content:
if isinstance(content, Image):
transcript += f"{message.source}: [Image]\n"
else:
transcript += f"{message.source}: {content}\n"
else:
raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}")
return transcript
async def save_state(self) -> Mapping[str, Any]:
    """Save the agent's state.

    The agent's only persistent state is that of its inner team, which is
    captured via the team's own save_state() and wrapped in a
    SocietyOfMindAgentState model.

    Returns:
        A serializable mapping produced by the state model's model_dump().
    """
    team_state = await self._team.save_state()
    state = SocietyOfMindAgentState(inner_team_state=team_state)
    return state.model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore the agent from a mapping previously produced by save_state().

    The mapping is validated against SocietyOfMindAgentState before the
    inner team's state is restored; validation errors propagate to the
    caller if the mapping does not match the expected schema.
    """
    society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
    await self._team.load_state(society_of_mind_state.inner_team_state)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Protocol, Sequence
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken

Expand All @@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
Expand All @@ -36,7 +36,7 @@ async def run(
def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MagenticOneOrchestratorState,
RoundRobinManagerState,
SelectorManagerState,
SocietyOfMindAgentState,
SwarmManagerState,
TeamState,
)
Expand All @@ -22,4 +23,5 @@
"SwarmManagerState",
"MagenticOneOrchestratorState",
"TeamState",
"SocietyOfMindAgentState",
]
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState):
n_rounds: int = Field(default=0)
n_stalls: int = Field(default=0)
type: str = Field(default="MagenticOneOrchestratorState")


class SocietyOfMindAgentState(BaseState):
    """State for a Society of Mind agent.

    Wraps the serialized state of the agent's inner team so the agent can
    be saved and restored as a unit.
    """

    # Serialized inner-team state, as produced by the team's save_state().
    inner_team_state: Mapping[str, Any] = Field(default_factory=dict)
    # Discriminator identifying this state type during deserialization.
    type: str = Field(default="SocietyOfMindAgentState")
Loading
Loading