Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for list of messages as team task input #4500

Merged
merged 30 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8139f7b
feat: add support for list of messages as team task input
iamarunbrahma Dec 3, 2024
424da3d
feat: enhance task handling to support single and multiple messages i…
iamarunbrahma Dec 4, 2024
9a7e0a6
Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/…
iamarunbrahma Dec 5, 2024
06e3d45
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 5, 2024
c0b34f3
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 6, 2024
71f3db5
Refactor message processing in handle_start to check for ChatMessage …
iamarunbrahma Dec 7, 2024
94c9e13
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 7, 2024
abaf4b9
update _base_group_chat_manager.py, _chat_agent_container.py and _mag…
iamarunbrahma Dec 7, 2024
c08b8cf
Refactor GroupChatStart to support multiple messages; update related …
iamarunbrahma Dec 7, 2024
cfe3aaf
Enhance task handling in _base_chat_agent.py to support a list of Cha…
iamarunbrahma Dec 8, 2024
ac011b8
Refactor validate_group_state method to accept a list of ChatMessages…
iamarunbrahma Dec 8, 2024
9712f2b
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 9, 2024
376f337
feat: add list input support for chat messages with unit tests
iamarunbrahma Dec 9, 2024
9666740
fix: type check for Annotated types
iamarunbrahma Dec 10, 2024
068008e
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 10, 2024
46381ec
fix(test): Update mock chat completion message content to match expec…
iamarunbrahma Dec 10, 2024
cd7a245
fix: ChatMessage to List[ChatMessage]
iamarunbrahma Dec 10, 2024
f2f37c3
Merge branch 'main' into list_messages_support
rysweet Dec 10, 2024
0f875d1
fixed mypy and pyright issues related to type checks
iamarunbrahma Dec 10, 2024
27d542e
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 11, 2024
cbe21af
fix: formatting of _base_group_chat
iamarunbrahma Dec 11, 2024
65d10cb
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 12, 2024
55feb5f
refactor: improve message handling in base group chat
iamarunbrahma Dec 12, 2024
b5c0d1a
refactor: streamline task validation and message handling in base gro…
iamarunbrahma Dec 12, 2024
e4aaccf
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 13, 2024
efea433
Refactor group chat message handling to validate all messages in the …
iamarunbrahma Dec 14, 2024
adbf850
Merge branch 'main' into list_messages_support
iamarunbrahma Dec 14, 2024
cb0a734
Update society of mind agent to use the list input task
ekzhu Dec 15, 2024
0699bca
Merge branch 'main' into list_messages_support
ekzhu Dec 15, 2024
6be1f73
fix doc example
ekzhu Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args

from autogen_core import CancellationToken

from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import (
AgentMessage,
ChatMessage,
TextMessage,
)
from ..state import BaseState


Expand Down Expand Up @@ -45,8 +49,9 @@ async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in :class:`BaseChatAgent`
simply calls :meth:`on_messages` and yields the messages in the response."""
and the final item is the response. The base implementation in
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
the messages in the response."""
response = await self.on_messages(messages, cancellation_token)
for inner_message in response.inner_messages or []:
yield inner_message
Expand All @@ -55,7 +60,7 @@ async def on_messages_stream(
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
Expand All @@ -69,7 +74,14 @@ async def run(
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
else:
Expand All @@ -83,7 +95,7 @@ async def run(
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
Expand All @@ -99,7 +111,15 @@ async def run_stream(
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
yield task
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Protocol, Sequence
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken

Expand All @@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
Expand All @@ -36,7 +36,7 @@ async def run(
def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, List, Mapping
from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args

from autogen_core import (
AgentId,
Expand All @@ -19,7 +19,7 @@

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ...messages import AgentMessage, ChatMessage, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
Expand Down Expand Up @@ -146,11 +146,18 @@ async def collect_output_messages(
message: GroupChatStart | GroupChatMessage | GroupChatTermination,
ctx: MessageContext,
) -> None:
event_logger.info(message.message)
if isinstance(message, GroupChatTermination):
"""Collect output messages from the group chat."""
if isinstance(message, GroupChatStart):
if message.messages is not None:
for msg in message.messages:
event_logger.info(msg)
await self._output_message_queue.put(msg)
elif isinstance(message, GroupChatMessage):
event_logger.info(message.message)
await self._output_message_queue.put(message.message)
elif isinstance(message, GroupChatTermination):
event_logger.info(message.message)
self._stop_reason = message.message.content
return
await self._output_message_queue.put(message.message)

await ClosureAgent.register_closure(
runtime,
Expand All @@ -165,15 +172,15 @@ async def collect_output_messages(
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
:meth:`run_stream` to run the team and then returns the final result.
Once the team is stopped, the termination condition is reset.

Args:
task (str | ChatMessage | None): The task to run the team with.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
Expand Down Expand Up @@ -264,15 +271,15 @@ async def main() -> None:
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.

Args:
task (str | ChatMessage | None): The task to run the team with.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
Expand Down Expand Up @@ -356,13 +363,22 @@ async def main() -> None:
"""

# Create the first chat message if the task is a string or a chat message.
first_chat_message: ChatMessage | None = None
first_chat_message: ChatMessage | List[ChatMessage] | None = None
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
if task is None:
pass
elif isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, get_args(ChatMessage)[0]):
first_chat_message = task
elif isinstance(task, list):
if not task:
raise ValueError("Task list cannot be empty")
if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task):
raise ValueError("All messages in task list must be valid ChatMessage types")
first_chat_message = task[0]
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
# Queue remaining messages for processing
for msg in task[1:]:
await self._output_message_queue.put(msg)
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Invalid task type: {type(task)}")

Expand All @@ -388,8 +404,15 @@ async def stop_runtime() -> None:
# Run the team by sending the start message to the group chat manager.
# The group chat manager will start the group chat by relaying the message to the participants
# and the closure agent.
if first_chat_message is not None:
if isinstance(first_chat_message, list):
messages = first_chat_message
else:
messages = [first_chat_message]
else:
messages = None
await self._runtime.send_message(
GroupChatStart(message=first_chat_message),
GroupChatStart(messages=messages),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
cancellation_token=cancellation_token,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
# Stop the group chat.
return

# Validate the group state given the start message.
await self.validate_group_state(message.message)
# Validate the group state given the start messages
await self.validate_group_state(message.messages)

if message.message is not None:
# Log the start message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
if message.messages is not None:
# Log all messages at once
await self.publish_message(
GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type)
)

# Relay the start message to the participants.
# Relay all messages at once to participants
await self.publish_message(
message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token
GroupChatStart(messages=message.messages),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=ctx.cancellation_token,
)

# Append the user message to the message thread.
self._message_thread.append(message.message)
# Append all messages to thread
self._message_thread.extend(message.messages)

# Check if the conversation should be terminated.
# Check termination condition after processing all messages
if self._termination_condition is not None:
stop_message = await self._termination_condition([message.message])
stop_message = await self._termination_condition(message.messages)
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message),
Expand All @@ -97,7 +101,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self._termination_condition.reset()
return

# Select a speaker to start the conversation.
# Select a speaker to start/continue the conversation
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
Expand Down Expand Up @@ -166,8 +170,13 @@ async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> No
await self.reset()

@abstractmethod
async def validate_group_state(self, message: ChatMessage | None) -> None:
"""Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event."""
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
"""Validate the state of the group chat given the start messages.
This is executed when the group chat manager receives a GroupChatStart event.

Args:
messages: A list of chat messages to validate, or None if no messages are provided.
"""
...

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAg
@event
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
"""Handle a start event by appending the content to the buffer."""
if message.message is not None:
self._message_buffer.append(message.message)
if message.messages is not None:
self._message_buffer.extend(message.messages)

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from pydantic import BaseModel

from ...base import Response
Expand All @@ -7,8 +9,8 @@
class GroupChatStart(BaseModel):
"""A request to start a group chat."""

message: ChatMessage | None = None
"""An optional user message to start the group chat."""
messages: List[ChatMessage] | None = None
"""An optional list of messages to start the group chat."""


class GroupChatAgentResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,18 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
)
# Stop the group chat.
return
assert message is not None and message.message is not None
assert message is not None and message.messages is not None

# Validate the group state given the start message.
await self.validate_group_state(message.message)
await self.validate_group_state([message.messages[0]])
iamarunbrahma marked this conversation as resolved.
Show resolved Hide resolved

# Log the start message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Outer Loop for first time
# Create the initial task ledger
#################################
self._task = self._content_to_str(message.message.content)
# Combine all message contents for task
self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages])
planning_conversation: List[LLMMessage] = []

# 1. GATHER FACTS
Expand Down Expand Up @@ -184,7 +185,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
return
await self._orchestrate_step(ctx.cancellation_token)

async def validate_group_state(self, message: ChatMessage | None) -> None:
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass

async def save_state(self) -> Mapping[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
)
self._next_speaker_index = 0

async def validate_group_state(self, message: ChatMessage | None) -> None:
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass

async def reset(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func

async def validate_group_state(self, message: ChatMessage | None) -> None:
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass

async def reset(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ def __init__(
)
self._current_speaker = participant_topic_types[0]

async def validate_group_state(self, message: ChatMessage | None) -> None:
"""Validate the start message for the group chat."""
# Check if the start message is a handoff message.
if isinstance(message, HandoffMessage):
if message.target not in self._participant_topic_types:
raise ValueError(
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
)
return
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
"""Validate the start messages for the group chat."""
# Check if any of the start messages is a handoff message.
if messages:
for message in messages:
if isinstance(message, HandoffMessage):
if message.target not in self._participant_topic_types:
raise ValueError(
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
)
return

# Check if there is a handoff message in the thread that is not targeting a valid participant.
for existing_message in reversed(self._message_thread):
if isinstance(existing_message, HandoffMessage):
Expand Down
Loading
Loading