Skip to content

Commit

Permalink
feat: enhance task handling to support single and multiple messages i…
Browse files Browse the repository at this point in the history
…n group chat
  • Loading branch information
iamarunbrahma committed Dec 4, 2024
1 parent 8139f7b commit 424da3d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Protocol, Sequence
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken

Expand All @@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
Expand All @@ -36,7 +36,7 @@ async def run(
def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,11 @@ async def collect_output_messages(
if isinstance(message, GroupChatTermination):
self._stop_reason = message.message.content
return
await self._output_message_queue.put(message.message)

# Handle single message or list of messages
messages = message.message if isinstance(message.message, List) else [message.message]
for msg in messages:
await self._output_message_queue.put(msg)

await ClosureAgent.register_closure(
runtime,
Expand All @@ -164,15 +168,15 @@ async def collect_output_messages(
async def run(
self,
*,
task: str | ChatMessage | list[ChatMessage] | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
:meth:`run_stream` to run the team and then returns the final result.
Once the team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
Expand Down Expand Up @@ -263,15 +267,15 @@ async def main() -> None:
async def run_stream(
self,
*,
task: str | ChatMessage | list[ChatMessage] | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | list[ChatMessage] | None): The task to run the team with. Can be a string, a single message, or a list of messages.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,28 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self.validate_group_state(message.message)

if message.message is not None:
# Log the start message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
messages_to_process = [message.message] if isinstance(message.message, ChatMessage) else message.message

# Relay the start message to the participants.
await self.publish_message(
message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token
)
# Log and relay each message
for msg in messages_to_process:
# Log the message
await self.publish_message(
GroupChatStart(message=msg), topic_id=DefaultTopicId(type=self._output_topic_type)
)

# Relay the message to participants
await self.publish_message(
GroupChatStart(message=msg),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=ctx.cancellation_token,
)

# Append the user message to the message thread.
self._message_thread.append(message.message)
# Append to message thread
self._message_thread.append(msg)

# Check if the conversation should be terminated.
# Check termination condition after processing all messages
if self._termination_condition is not None:
stop_message = await self._termination_condition([message.message])
stop_message = await self._termination_condition(messages_to_process)
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message),
Expand All @@ -97,7 +105,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self._termination_condition.reset()
return

# Select a speaker to start the conversation.
# Select a speaker to start/continue the conversation
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from pydantic import BaseModel

from ...base import Response
Expand All @@ -7,8 +9,8 @@
class GroupChatStart(BaseModel):
"""A request to start a group chat."""

message: ChatMessage | None = None
"""An optional user message to start the group chat."""
message: ChatMessage | List[ChatMessage] | None = None
"""An optional user message or list of messages to start the group chat."""


class GroupChatAgentResponse(BaseModel):
Expand Down

0 comments on commit 424da3d

Please sign in to comment.