From 965d9a461b85aff849dd09e77479527529bdcf30 Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Sat, 12 Apr 2025 14:43:39 +0900 Subject: [PATCH 01/11] FEAT: select group chat could using stream --- .../teams/_group_chat/_selector_group_chat.py | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 1aa5aa337065..cc03b9e6bb9c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -5,7 +5,15 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast from autogen_core import AgentRuntime, Component, ComponentModel -from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage +from autogen_core.logging import LLMStreamEndEvent +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + ModelFamily, + SystemMessage, + UserMessage, +) from pydantic import BaseModel from typing_extensions import Self @@ -55,6 +63,7 @@ def __init__( selector_func: Optional[SelectorFuncType], max_selector_attempts: int, candidate_func: Optional[CandidateFuncType], + streaming: bool = False, ) -> None: super().__init__( name, @@ -77,6 +86,7 @@ def __init__( self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) + self._streaming = streaming async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -192,7 +202,18 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st num_attempts = 0 while num_attempts < max_attempts: num_attempts += 1 - response = await self._model_client.create(messages=select_speaker_messages) + if self._streaming: + message: CreateResult | str = "" + async for _message in self._model_client.create_stream(messages=select_speaker_messages): + if isinstance(_message, LLMStreamEndEvent): + break + message = _message + if isinstance(message, CreateResult): + response = message + else: + raise ValueError("Model failed to select a speaker.") + else: + response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) select_speaker_messages.append(AssistantMessage(content=response.content, source="selector")) # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed. @@ -278,6 +299,7 @@ class SelectorGroupChatConfig(BaseModel): allow_repeated_speaker: bool # selector_func: ComponentModel | None max_selector_attempts: int = 3 + streaming: bool = False class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -307,7 +329,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. - + streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -449,6 +471,7 @@ def __init__( selector_func: Optional[SelectorFuncType] = None, candidate_func: Optional[CandidateFuncType] = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, + streaming: bool = False, ): super().__init__( participants, @@ -468,6 +491,7 @@ def __init__( self._selector_func = selector_func self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func + self._streaming = streaming def _create_group_chat_manager_factory( self, @@ -499,6 +523,7 @@ def _create_group_chat_manager_factory( self._selector_func, self._max_selector_attempts, self._candidate_func, + self._streaming, ) def _to_config(self) -> SelectorGroupChatConfig: From 7aff3b6c220fb4567f66f10dcccbd01154dc4c44 Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Sun, 13 Apr 2025 15:04:38 +0900 Subject: [PATCH 02/11] FIX: delete useless if block --- .../autogen_agentchat/teams/_group_chat/_selector_group_chat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index cc03b9e6bb9c..7fe7c4014dc8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -205,8 +205,6 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st if self._streaming: message: CreateResult | str = "" async for _message in self._model_client.create_stream(messages=select_speaker_messages): - if isinstance(_message, LLMStreamEndEvent): - break message = _message if isinstance(message, CreateResult): response = message From fda2d2808ae41ec6d20f82bb21ca387dba810cff Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Thu, 17 Apr 2025 16:35:05 +0900 Subject: [PATCH 03/11] clean --- .../teams/_group_chat/_selector_group_chat.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 635d0b976b53..cb9a53296594 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -5,7 +5,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast from autogen_core import AgentRuntime, Component, ComponentModel -from autogen_core.logging import LLMStreamEndEvent from autogen_core.models import ( AssistantMessage, ChatCompletionClient, @@ -64,7 +63,7 @@ def __init__( max_selector_attempts: int, candidate_func: Optional[CandidateFuncType], emit_team_events: bool, - streaming: bool = False, + model_client_streaming: bool = False, ) -> None: super().__init__( name, @@ -88,7 +87,7 @@ def __init__( self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) - self._streaming = streaming + self._model_client_streaming = model_client_streaming async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -204,7 +203,7 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st num_attempts = 0 while num_attempts < max_attempts: num_attempts += 1 - if self._streaming: + if self._model_client_streaming: message: CreateResult | str = "" async for _message in self._model_client.create_stream(messages=select_speaker_messages): message = _message @@ -300,7 +299,7 @@ class SelectorGroupChatConfig(BaseModel): # selector_func: ComponentModel | None max_selector_attempts: int = 3 emit_team_events: bool = False - streaming: bool = False + model_client_streaming: bool = False class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -331,7 +330,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. - streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False. + model_client_streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -474,7 +473,7 @@ def __init__( candidate_func: Optional[CandidateFuncType] = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, - streaming: bool = False, + model_client_streaming: bool = False, ): super().__init__( participants, @@ -495,7 +494,7 @@ def __init__( self._selector_func = selector_func self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func - self._streaming = streaming + self._model_client_streaming = model_client_streaming def _create_group_chat_manager_factory( self, @@ -528,7 +527,7 @@ def _create_group_chat_manager_factory( self._max_selector_attempts, self._candidate_func, self._emit_team_events, - self._streaming, + self._model_client_streaming, ) def _to_config(self) -> SelectorGroupChatConfig: @@ -542,7 +541,7 @@ def _to_config(self) -> SelectorGroupChatConfig: max_selector_attempts=self._max_selector_attempts, # selector_func=self._selector_func.dump_component() if self._selector_func else None, emit_team_events=self._emit_team_events, - streaming=self._streaming, + model_client_streaming=self._model_client_streaming, ) @classmethod @@ -561,5 +560,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self: # if config.selector_func # else None, emit_team_events=config.emit_team_events, - streaming=config.streaming, + model_client_streaming=config.model_client_streaming, ) From 31d0d66ca4b7e290762ef2b045eb04eab38caba4 Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Thu, 17 Apr 2025 18:12:23 +0900 Subject: [PATCH 04/11] done - maybe need to adding testcase --- .../teams/_group_chat/_selector_group_chat.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index cb9a53296594..01af43e59866 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -23,11 +23,14 @@ BaseAgentEvent, BaseChatMessage, MessageFactory, + ModelClientStreamingChunkEvent, ) from ...state import SelectorManagerState from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager -from ._events import GroupChatTermination +from ._events import ( + GroupChatTermination, +) trace_logger = logging.getLogger(TRACE_LOGGER_NAME) @@ -207,6 +210,11 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st message: CreateResult | str = "" async for _message in self._model_client.create_stream(messages=select_speaker_messages): message = _message + if self._emit_team_events: + if isinstance(message, str): + await self._output_message_queue.put( + ModelClientStreamingChunkEvent(content=cast(str, _message), source=self._name) + ) if isinstance(message, CreateResult): response = message else: From d5630327728541801578814a68fb272bae40359a Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Sat, 19 Apr 2025 10:33:15 +0900 Subject: [PATCH 05/11] FIX: adding full message of content of stream. --- .../src/autogen_agentchat/messages.py | 12 ++++++++++++ .../teams/_group_chat/_selector_group_chat.py | 16 +++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 206bcc6e4295..9aa4c5942596 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -529,6 +529,18 @@ class SelectSpeakerEvent(BaseAgentEvent): def to_text(self) -> str: return str(self.content) + + +class SelectorEvent(BaseAgentEvent): + """An event signaling the selection of a specific agent.""" + + content: str + """The names of the selected agent.""" + + type: Literal["SelectorEvent"] = "SelectorEvent" + + def to_text(self) -> str: + return str(self.content) class MessageFactory: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 01af43e59866..97106c46ba6b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -24,6 +24,7 @@ BaseChatMessage, MessageFactory, ModelClientStreamingChunkEvent, + SelectorEvent, ) from ...state import SelectorManagerState from ._base_group_chat import BaseGroupChat @@ -215,10 +216,19 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st await self._output_message_queue.put( ModelClientStreamingChunkEvent(content=cast(str, _message), source=self._name) ) - if isinstance(message, CreateResult): - response = message else: - raise ValueError("Model failed to select a speaker.") + if isinstance(message, CreateResult): + response = message + else: + raise ValueError("Model failed to select a speaker.") + + if isinstance(response.content, str): + if self._emit_team_events: + await self._output_message_queue.put(SelectorEvent(content=response.content, source=self._name)) + else: + response.content = "" # fallback to empty string + if self._emit_team_events: + await self._output_message_queue.put(SelectorEvent(content="Model failed to select a valid content type(it must str)", source=self._name)) else: response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) From 3486d6de74f2109c7df58bf31363e2885c072633 Mon Sep 17 00:00:00 2001 From: "chiyoung.song" Date: Sat, 19 Apr 2025 10:54:13 +0900 Subject: [PATCH 06/11] Clean, Add test --- .../src/autogen_agentchat/messages.py | 2 +- .../teams/_group_chat/_selector_group_chat.py | 17 ++++-- .../tests/test_group_chat.py | 59 +++++++++++++++++++ 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 9aa4c5942596..e50836a9747e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -529,7 +529,7 @@ class SelectSpeakerEvent(BaseAgentEvent): def to_text(self) -> str: return str(self.content) - + class SelectorEvent(BaseAgentEvent): """An event signaling the selection of a specific agent.""" diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 97106c46ba6b..480d5f548405 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -29,9 +29,7 @@ from ...state import SelectorManagerState from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager -from ._events import ( - GroupChatTermination, -) +from ._events import GroupChatTermination trace_logger = logging.getLogger(TRACE_LOGGER_NAME) @@ -221,14 +219,21 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st response = message else: raise ValueError("Model failed to select a speaker.") - + if isinstance(response.content, str): if self._emit_team_events: - await self._output_message_queue.put(SelectorEvent(content=response.content, source=self._name)) + await self._output_message_queue.put( + SelectorEvent(content=response.content, source=self._name) + ) else: response.content = "" # fallback to empty string if self._emit_team_events: - await self._output_message_queue.put(SelectorEvent(content="Model failed to select a valid content type(it must str)", source=self._name)) + await self._output_message_queue.put( + SelectorEvent( + content="Model failed to select a valid content type(it must str)", + source=self._name, + ) + ) else: response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 947d9595ba97..91e351d3abc3 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -19,6 +19,7 @@ BaseChatMessage, HandoffMessage, MultiModalMessage, + SelectorEvent, SelectSpeakerEvent, StopMessage, StructuredMessage, @@ -1698,3 +1699,61 @@ async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None) ) assert manager1._message_thread == manager2._message_thread # pyright: ignore + + +@pytest.mark.asyncio +async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> None: + model_client = ReplayChatCompletionClient( + ["agent3", "agent2", "agent1", "agent2", "agent1"], + ) + agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2) + agent2 = _EchoAgent("agent2", description="echo agent 2") + agent3 = _EchoAgent("agent3", description="echo agent 3") + termination = TextMentionTermination("TERMINATE") + team = SelectorGroupChat( + participants=[agent1, agent2, agent3], + model_client=model_client, + termination_condition=termination, + runtime=runtime, + emit_team_events=True, + model_client_streaming=True, + ) + result = await team.run( + task="Write a program that prints 'Hello, world!'", + ) + + assert len(result.messages) == 16 + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], SelectorEvent) + assert isinstance(result.messages[2], SelectSpeakerEvent) + assert isinstance(result.messages[3], TextMessage) + assert isinstance(result.messages[4], SelectorEvent) + assert isinstance(result.messages[5], SelectSpeakerEvent) + assert isinstance(result.messages[6], TextMessage) + assert isinstance(result.messages[7], SelectorEvent) + assert isinstance(result.messages[8], SelectSpeakerEvent) + assert isinstance(result.messages[9], TextMessage) + assert isinstance(result.messages[10], SelectorEvent) + assert isinstance(result.messages[11], SelectSpeakerEvent) + assert isinstance(result.messages[12], TextMessage) + assert isinstance(result.messages[13], SelectorEvent) + assert isinstance(result.messages[14], SelectSpeakerEvent) + assert isinstance(result.messages[15], StopMessage) + + assert result.messages[0].content == "Write a program that prints 'Hello, world!'" + assert result.messages[1].content == "agent3" + assert result.messages[2].content == ["agent3"] + assert result.messages[3].source == "agent3" + assert result.messages[4].content == "agent2" + assert result.messages[5].content == ["agent2"] + assert result.messages[6].source == "agent2" + assert result.messages[7].content == "agent1" + assert result.messages[8].content == ["agent1"] + assert result.messages[9].source == "agent1" + assert result.messages[10].content == "agent2" + assert result.messages[11].content == ["agent2"] + assert result.messages[12].source == "agent2" + assert result.messages[13].content == "agent1" + assert result.messages[14].content == ["agent1"] + assert result.messages[15].source == "agent1" + assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned" From 4abae5aa8d5765d69c41f5c6a538a3c7a99debc2 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 21 Apr 2025 11:37:13 -0700 Subject: [PATCH 07/11] Apply suggestions from code review --- .../autogen-agentchat/src/autogen_agentchat/messages.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index e50836a9747e..ac21fbaa257d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -532,10 +532,10 @@ def to_text(self) -> str: class SelectorEvent(BaseAgentEvent): - """An event signaling the selection of a specific agent.""" + """An event emitted from the `SelectorGroupChat`.""" content: str - """The names of the selected agent.""" + """The content of the event.""" type: Literal["SelectorEvent"] = "SelectorEvent" From e68dbbfda0244e8f6f47f4b6db98614e90c8849b Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 21 Apr 2025 11:40:46 -0700 Subject: [PATCH 08/11] Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py --- .../autogen_agentchat/teams/_group_chat/_selector_group_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 480d5f548405..dc1244857f56 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -353,7 +353,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. - model_client_streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False. + model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. From 05c250af0a50487c363c5108e0407dfc78aea7bd Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 21 Apr 2025 12:10:39 -0700 Subject: [PATCH 09/11] Fix tests --- .../teams/_group_chat/_selector_group_chat.py | 44 +++++-------- .../tests/test_group_chat.py | 65 +++++++++---------- 2 files changed, 45 insertions(+), 64 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index dc1244857f56..7108f2067336 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -205,35 +205,23 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st num_attempts = 0 while num_attempts < max_attempts: num_attempts += 1 - if self._model_client_streaming: - message: CreateResult | str = "" - async for _message in self._model_client.create_stream(messages=select_speaker_messages): - message = _message - if self._emit_team_events: - if isinstance(message, str): - await self._output_message_queue.put( - ModelClientStreamingChunkEvent(content=cast(str, _message), source=self._name) - ) - else: - if isinstance(message, CreateResult): - response = message - else: - raise ValueError("Model failed to select a speaker.") - - if isinstance(response.content, str): - if self._emit_team_events: - await self._output_message_queue.put( - SelectorEvent(content=response.content, source=self._name) - ) + if self._model_client_streaming and self._emit_team_events: + chunk: CreateResult | str = "" + async for _chunk in self._model_client.create_stream(messages=select_speaker_messages): + chunk = _chunk + if isinstance(chunk, str): + await self._output_message_queue.put( + ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name) + ) else: - response.content = "" # fallback to empty string - if self._emit_team_events: - await self._output_message_queue.put( - SelectorEvent( - content="Model failed to select a valid content type(it must str)", - source=self._name, - ) - ) + assert isinstance(chunk, CreateResult) + assert isinstance(chunk.content, str) + await self._output_message_queue.put( + SelectorEvent(content=chunk.content, source=self._name) + ) + # The last chunk must be CreateResult. + assert isinstance(chunk, CreateResult) + response = chunk else: response = await self._model_client.create(messages=select_speaker_messages) assert isinstance(response.content, str) diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 91e351d3abc3..5307817fd571 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -13,11 +13,12 @@ CodeExecutorAgent, ) from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition -from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination +from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination, StopMessageTermination from autogen_agentchat.messages import ( BaseAgentEvent, BaseChatMessage, HandoffMessage, + ModelClientStreamingChunkEvent, MultiModalMessage, SelectorEvent, SelectSpeakerEvent, @@ -1704,14 +1705,13 @@ async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None) @pytest.mark.asyncio async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> None: model_client = ReplayChatCompletionClient( - ["agent3", "agent2", "agent1", "agent2", "agent1"], + ["the agent should be agent2"], ) - agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2) - agent2 = _EchoAgent("agent2", description="echo agent 2") + agent2 = _StopAgent("agent2", description="stop agent 2", stop_at=0) agent3 = _EchoAgent("agent3", description="echo agent 3") - termination = TextMentionTermination("TERMINATE") + termination = StopMessageTermination() team = SelectorGroupChat( - participants=[agent1, agent2, agent3], + participants=[agent2, agent3], model_client=model_client, termination_condition=termination, runtime=runtime, @@ -1722,38 +1722,31 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No task="Write a program that prints 'Hello, world!'", ) - assert len(result.messages) == 16 + assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[1], SelectorEvent) assert isinstance(result.messages[2], SelectSpeakerEvent) - assert isinstance(result.messages[3], TextMessage) - assert isinstance(result.messages[4], SelectorEvent) - assert isinstance(result.messages[5], SelectSpeakerEvent) - assert isinstance(result.messages[6], TextMessage) - assert isinstance(result.messages[7], SelectorEvent) - assert isinstance(result.messages[8], SelectSpeakerEvent) - assert isinstance(result.messages[9], TextMessage) - assert isinstance(result.messages[10], SelectorEvent) - assert isinstance(result.messages[11], SelectSpeakerEvent) - assert isinstance(result.messages[12], TextMessage) - assert isinstance(result.messages[13], SelectorEvent) - assert isinstance(result.messages[14], SelectSpeakerEvent) - assert isinstance(result.messages[15], StopMessage) + assert isinstance(result.messages[3], StopMessage) assert result.messages[0].content == "Write a program that prints 'Hello, world!'" - assert result.messages[1].content == "agent3" - assert result.messages[2].content == ["agent3"] - assert result.messages[3].source == "agent3" - assert result.messages[4].content == "agent2" - assert result.messages[5].content == ["agent2"] - assert result.messages[6].source == "agent2" - assert result.messages[7].content == "agent1" - assert result.messages[8].content == ["agent1"] - assert result.messages[9].source == "agent1" - assert result.messages[10].content == "agent2" - assert result.messages[11].content == ["agent2"] - assert result.messages[12].source == "agent2" - assert result.messages[13].content == "agent1" - assert result.messages[14].content == ["agent1"] - assert result.messages[15].source == "agent1" - assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned" + assert result.messages[1].content == "the agent should be agent2" + assert result.messages[2].content == ["agent2"] + assert result.messages[3].source == "agent2" + assert result.stop_reason is not None and result.stop_reason == "Stop message received" + + # Test streaming + await team.reset() + model_client.reset() + index = 0 + streaming = [] + async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"): + if isinstance(message, TaskResult): + assert message == result + elif isinstance(message, ModelClientStreamingChunkEvent): + streaming.append(message) + else: + if streaming: + assert message.content == "".join([chunk.content for chunk in streaming]) + streaming = [] + assert message == result.messages[index] + index += 1 From c9f75827a43ee2aebd8b07a985ba523163d32c96 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 21 Apr 2025 12:18:23 -0700 Subject: [PATCH 10/11] fix --- .../teams/_group_chat/_selector_group_chat.py | 4 +--- .../autogen-agentchat/tests/test_group_chat.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 7108f2067336..a55c71d4387d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -216,9 +216,7 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st else: assert isinstance(chunk, CreateResult) assert isinstance(chunk.content, str) - await self._output_message_queue.put( - SelectorEvent(content=chunk.content, source=self._name) - ) + await self._output_message_queue.put(SelectorEvent(content=chunk.content, source=self._name)) # The last chunk must be CreateResult. assert isinstance(chunk, CreateResult) response = chunk diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 5307817fd571..971dec1ce54b 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -13,7 +13,12 @@ CodeExecutorAgent, ) from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition -from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination, StopMessageTermination +from autogen_agentchat.conditions import ( + HandoffTermination, + MaxMessageTermination, + StopMessageTermination, + TextMentionTermination, +) from autogen_agentchat.messages import ( BaseAgentEvent, BaseChatMessage, @@ -1738,15 +1743,16 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No await team.reset() model_client.reset() index = 0 - streaming = [] + streaming: List[str] = [] async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"): if isinstance(message, TaskResult): assert message == result elif isinstance(message, ModelClientStreamingChunkEvent): - streaming.append(message) + streaming.append(message.content) else: if streaming: - assert message.content == "".join([chunk.content for chunk in streaming]) + assert isinstance(message, SelectorEvent) + assert message.content == "".join([chunk for chunk in streaming]) streaming = [] assert message == result.messages[index] index += 1 From f698469759c145582e3a0d81c292c0a1a88902b6 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 21 Apr 2025 15:31:20 -0700 Subject: [PATCH 11/11] Fix --- .../teams/_group_chat/_selector_group_chat.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index a55c71d4387d..2af568ac4981 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -205,18 +205,21 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st num_attempts = 0 while num_attempts < max_attempts: num_attempts += 1 - if self._model_client_streaming and self._emit_team_events: + if self._model_client_streaming: chunk: CreateResult | str = "" async for _chunk in self._model_client.create_stream(messages=select_speaker_messages): chunk = _chunk - if isinstance(chunk, str): - await self._output_message_queue.put( - ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name) - ) - else: - assert isinstance(chunk, CreateResult) - assert isinstance(chunk.content, str) - await self._output_message_queue.put(SelectorEvent(content=chunk.content, source=self._name)) + if self._emit_team_events: + if isinstance(chunk, str): + await self._output_message_queue.put( + ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name) + ) + else: + assert isinstance(chunk, CreateResult) + assert isinstance(chunk.content, str) + await self._output_message_queue.put( + SelectorEvent(content=chunk.content, source=self._name) + ) # The last chunk must be CreateResult. assert isinstance(chunk, CreateResult) response = chunk