diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 206bcc6e4295..ac21fbaa257d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -531,6 +531,18 @@ def to_text(self) -> str: return str(self.content) +class SelectorEvent(BaseAgentEvent): + """An event emitted from the `SelectorGroupChat`.""" + + content: str + """The content of the event.""" + + type: Literal["SelectorEvent"] = "SelectorEvent" + + def to_text(self) -> str: + return str(self.content) + + class MessageFactory: """:meta private: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 2a7b15889ec3..2af568ac4981 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -5,7 +5,14 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast from autogen_core import AgentRuntime, Component, ComponentModel -from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + ModelFamily, + SystemMessage, + UserMessage, +) from pydantic import BaseModel from typing_extensions import Self @@ -16,6 +23,8 @@ BaseAgentEvent, BaseChatMessage, MessageFactory, + ModelClientStreamingChunkEvent, + SelectorEvent, ) from ...state import SelectorManagerState from ._base_group_chat import BaseGroupChat @@ -56,6 +65,7 @@ def __init__( max_selector_attempts: int, candidate_func: Optional[CandidateFuncType], emit_team_events: bool, + model_client_streaming: bool = False, ) -> None: super().__init__( name, @@ -79,6 +89,7 @@ def __init__( self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) + self._model_client_streaming = model_client_streaming async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -194,7 +205,26 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st num_attempts = 0 while num_attempts < max_attempts: num_attempts += 1 - response = await self._model_client.create(messages=select_speaker_messages) + if self._model_client_streaming: + chunk: CreateResult | str = "" + async for _chunk in self._model_client.create_stream(messages=select_speaker_messages): + chunk = _chunk + if self._emit_team_events: + if isinstance(chunk, str): + await self._output_message_queue.put( + ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name) + ) + else: + assert isinstance(chunk, CreateResult) + assert isinstance(chunk.content, str) + await self._output_message_queue.put( + SelectorEvent(content=chunk.content, source=self._name) + ) + # The last chunk must be CreateResult. + assert isinstance(chunk, CreateResult) + response = chunk + else: + response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed. @@ -281,6 +311,7 @@ class SelectorGroupChatConfig(BaseModel): # selector_func: ComponentModel | None max_selector_attempts: int = 3 emit_team_events: bool = False + model_client_streaming: bool = False class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -311,6 +342,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. + model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -453,6 +485,7 @@ def __init__( candidate_func: Optional[CandidateFuncType] = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, + model_client_streaming: bool = False, ): super().__init__( participants, @@ -473,6 +506,7 @@ def __init__( self._selector_func = selector_func self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func + self._model_client_streaming = model_client_streaming def _create_group_chat_manager_factory( self, @@ -505,6 +539,7 @@ def _create_group_chat_manager_factory( self._max_selector_attempts, self._candidate_func, self._emit_team_events, + self._model_client_streaming, ) def _to_config(self) -> SelectorGroupChatConfig: @@ -518,6 +553,7 @@ def _to_config(self) -> SelectorGroupChatConfig: max_selector_attempts=self._max_selector_attempts, # selector_func=self._selector_func.dump_component() if self._selector_func else None, emit_team_events=self._emit_team_events, + model_client_streaming=self._model_client_streaming, ) @classmethod @@ -536,4 +572,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self: # if config.selector_func # else None, emit_team_events=config.emit_team_events, + model_client_streaming=config.model_client_streaming, ) diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 947d9595ba97..971dec1ce54b 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -13,12 +13,19 @@ CodeExecutorAgent, ) from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition -from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination +from autogen_agentchat.conditions import ( + HandoffTermination, + MaxMessageTermination, + StopMessageTermination, + TextMentionTermination, +) from autogen_agentchat.messages import ( BaseAgentEvent, BaseChatMessage, HandoffMessage, + ModelClientStreamingChunkEvent, MultiModalMessage, + SelectorEvent, SelectSpeakerEvent, StopMessage, StructuredMessage, @@ -1698,3 +1705,54 @@ async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None) ) assert manager1._message_thread == manager2._message_thread # pyright: ignore + + +@pytest.mark.asyncio +async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> None: + model_client = ReplayChatCompletionClient( + ["the agent should be agent2"], + ) + agent2 = _StopAgent("agent2", description="stop agent 2", stop_at=0) + agent3 = _EchoAgent("agent3", description="echo agent 3") + termination = StopMessageTermination() + team = SelectorGroupChat( + participants=[agent2, agent3], + model_client=model_client, + termination_condition=termination, + runtime=runtime, + emit_team_events=True, + model_client_streaming=True, + ) + result = await team.run( + task="Write a program that prints 'Hello, world!'", + ) + + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], SelectorEvent) + assert isinstance(result.messages[2], SelectSpeakerEvent) + assert isinstance(result.messages[3], StopMessage) + + assert result.messages[0].content == "Write a program that prints 'Hello, world!'" + assert result.messages[1].content == "the agent should be agent2" + assert result.messages[2].content == ["agent2"] + assert result.messages[3].source == "agent2" + assert result.stop_reason is not None and result.stop_reason == "Stop message received" + + # Test streaming + await team.reset() + model_client.reset() + index = 0 + streaming: List[str] = [] + async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"): + if isinstance(message, TaskResult): + assert message == result + elif isinstance(message, ModelClientStreamingChunkEvent): + streaming.append(message.content) + else: + if streaming: + assert isinstance(message, SelectorEvent) + assert message.content == "".join([chunk for chunk in streaming]) + streaming = [] + assert message == result.messages[index] + index += 1