Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for list of messages as team task input #4500

Merged
merged 30 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8139f7b
feat: add support for list of messages as team task input
iamarunbrahma Dec 3, 2024
424da3d
feat: enhance task handling to support single and multiple messages i…
iamarunbrahma Dec 4, 2024
9a7e0a6
Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/…
iamarunbrahma Dec 5, 2024
06e3d45
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 5, 2024
c0b34f3
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 6, 2024
71f3db5
Refactor message processing in handle_start to check for ChatMessage …
iamarunbrahma Dec 7, 2024
94c9e13
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 7, 2024
abaf4b9
update _base_group_chat_manager.py, _chat_agent_container.py and _mag…
iamarunbrahma Dec 7, 2024
c08b8cf
Refactor GroupChatStart to support multiple messages; update related …
iamarunbrahma Dec 7, 2024
cfe3aaf
Enhance task handling in _base_chat_agent.py to support a list of Cha…
iamarunbrahma Dec 8, 2024
ac011b8
Refactor validate_group_state method to accept a list of ChatMessages…
iamarunbrahma Dec 8, 2024
9712f2b
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 9, 2024
376f337
feat: add list input support for chat messages with unit tests
iamarunbrahma Dec 9, 2024
9666740
fix: type check for Annotated types
iamarunbrahma Dec 10, 2024
068008e
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 10, 2024
46381ec
fix(test): Update mock chat completion message content to match expec…
iamarunbrahma Dec 10, 2024
cd7a245
fix: ChatMessage to List[ChatMessage]
iamarunbrahma Dec 10, 2024
f2f37c3
Merge branch 'main' into list_messages_support
rysweet Dec 10, 2024
0f875d1
fixed mypy and pyright issues related to type checks
iamarunbrahma Dec 10, 2024
27d542e
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 11, 2024
cbe21af
fix: formatting of _base_group_chat
iamarunbrahma Dec 11, 2024
65d10cb
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 12, 2024
55feb5f
refactor: improve message handling in base group chat
iamarunbrahma Dec 12, 2024
b5c0d1a
refactor: streamline task validation and message handling in base gro…
iamarunbrahma Dec 12, 2024
e4aaccf
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 13, 2024
efea433
Refactor group chat message handling to validate all messages in the …
iamarunbrahma Dec 14, 2024
adbf850
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 14, 2024
cb0a734
Update society of mind agent to use the list input task
ekzhu Dec 15, 2024
0699bca
Merge branch 'main' into list_messages_support
ekzhu Dec 15, 2024
6be1f73
fix doc example
ekzhu Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Protocol, Sequence
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken

Expand All @@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
Expand All @@ -36,7 +36,7 @@ async def run(
def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ async def collect_output_messages(
if isinstance(message, GroupChatTermination):
self._stop_reason = message.message.content
return
await self._output_message_queue.put(message.message)

# Handle single message or list of messages
messages = message.message if isinstance(message.message, List) else [message.message]
for msg in messages:
await self._output_message_queue.put(msg)

await ClosureAgent.register_closure(
runtime,
Expand All @@ -165,15 +169,15 @@ async def collect_output_messages(
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
:meth:`run_stream` to run the team and then returns the final result.
Once the team is stopped, the termination condition is reset.

Args:
task (str | ChatMessage | None): The task to run the team with.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
Expand Down Expand Up @@ -264,15 +268,15 @@ async def main() -> None:
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.

Args:
task (str | ChatMessage | None): The task to run the team with.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
Expand Down Expand Up @@ -361,8 +365,17 @@ async def main() -> None:
pass
elif isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
first_chat_message = task
elif isinstance(task, list):
if not task:
raise ValueError("Task list cannot be empty")
if not all(isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)) for msg in task):
raise ValueError("All messages in task list must be valid ChatMessage types")
first_chat_message = task[0]
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
# Queue remaining messages for processing
for msg in task[1:]:
await self._output_message_queue.put(msg)
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Invalid task type: {type(task)}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self.validate_group_state(message.message)

if message.message is not None:
# Log the start message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
messages_to_process = [message.message] if isinstance(message.message, ChatMessage) else message.message

# Relay the start message to the participants.
await self.publish_message(
message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token
)
# Log and relay each message
for msg in messages_to_process:
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
# Log the message
await self.publish_message(
GroupChatStart(message=msg), topic_id=DefaultTopicId(type=self._output_topic_type)
)

# Relay the message to participants
await self.publish_message(
GroupChatStart(message=msg),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=ctx.cancellation_token,
)

# Append the user message to the message thread.
self._message_thread.append(message.message)
# Append to message thread
self._message_thread.append(msg)

# Check if the conversation should be terminated.
# Check termination condition after processing all messages
if self._termination_condition is not None:
stop_message = await self._termination_condition([message.message])
stop_message = await self._termination_condition([msg])
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message),
Expand All @@ -97,7 +105,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self._termination_condition.reset()
return

# Select a speaker to start the conversation.
# Select a speaker to start/continue the conversation
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from pydantic import BaseModel

from ...base import Response
Expand All @@ -7,8 +9,8 @@
class GroupChatStart(BaseModel):
"""A request to start a group chat."""

message: ChatMessage | None = None
"""An optional user message to start the group chat."""
message: ChatMessage | List[ChatMessage] | None = None
"""An optional user message or list of messages to start the group chat."""
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved


class GroupChatAgentResponse(BaseModel):
Expand Down
47 changes: 47 additions & 0 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,3 +1033,50 @@ async def test_swarm_with_handoff_termination() -> None:
assert result.messages[1].content == "Transferred to second_agent."
assert result.messages[2].content == "Transferred to third_agent."
assert result.messages[3].content == "Transferred to non_existing_agent."


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_message_list() -> None:
# Create a simple team with echo agents
agent1 = _EchoAgent("Agent1", "First agent")
agent2 = _EchoAgent("Agent2", "Second agent")
termination = MaxMessageTermination(4) # Stop after 4 messages
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

# Create a list of messages
messages = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
TextMessage(content="Message 3", source="user"),
]

# Run the team with the message list
result = await team.run(task=messages)

# Verify the messages were processed in order
assert len(result.messages) == 6 # All initial messages + echoes until termination
assert result.messages[0].content == "Message 2" # Second message from queue
assert result.messages[1].content == "Message 3" # Third message from queue
assert result.messages[2].content == "Message 1" # First message (processed first)
assert result.messages[3].content == "Message 1" # Echo from first agent
assert result.messages[4].content == "Message 1" # Echo from second agent
assert result.messages[5].content == "Message 1" # Echo from first agent (before termination)
assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4"

# Test with streaming
await team.reset()
index = 0
async for message in team.run_stream(task=messages):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

# Test with invalid message list
with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
await team.run(task=["not a message"])

# Test with empty message list
with pytest.raises(ValueError, match="Task list cannot be empty"):
await team.run(task=[])