Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
965d9a4
FEAT: select group chat could using stream
SongChiYoung Apr 12, 2025
d42a382
Merge remote-tracking branch 'upstream/main' into feature/model_clien…
SongChiYoung Apr 13, 2025
7aff3b6
FIX: delete useless if block
SongChiYoung Apr 13, 2025
99194d4
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 15, 2025
7e24527
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 16, 2025
55a1341
Merge
SongChiYoung Apr 17, 2025
fda2d28
clean
SongChiYoung Apr 17, 2025
31d0d66
done - maybe need to adding testcase
SongChiYoung Apr 17, 2025
29f2c0f
Merge remote-tracking branch 'upstream/main' into feature/model_clien…
SongChiYoung Apr 17, 2025
cd9e000
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 18, 2025
d563032
FIX: adding full message of content of stream.
SongChiYoung Apr 19, 2025
3486d6d
Clean, Add test
SongChiYoung Apr 19, 2025
548171a
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 19, 2025
4abae5a
Apply suggestions from code review
ekzhu Apr 21, 2025
e68dbbf
Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/…
ekzhu Apr 21, 2025
4fabfce
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
ekzhu Apr 21, 2025
b711739
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 21, 2025
05c250a
Fix tests
ekzhu Apr 21, 2025
c9f7582
fix
ekzhu Apr 21, 2025
ef67c17
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
ekzhu Apr 21, 2025
f698469
Fix
ekzhu Apr 21, 2025
94a8f9b
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
ekzhu Apr 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,18 @@ def to_text(self) -> str:
return str(self.content)


class SelectorEvent(BaseAgentEvent):
"""An event signaling the selection of a specific agent."""
Comment thread
ekzhu marked this conversation as resolved.
Outdated

content: str
"""The names of the selected agent."""
Comment thread
ekzhu marked this conversation as resolved.
Outdated

type: Literal["SelectorEvent"] = "SelectorEvent"

def to_text(self) -> str:
return str(self.content)


class MessageFactory:
""":meta private:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
ModelFamily,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel
from typing_extensions import Self

Expand All @@ -16,6 +23,8 @@
BaseAgentEvent,
BaseChatMessage,
MessageFactory,
ModelClientStreamingChunkEvent,
SelectorEvent,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -56,6 +65,7 @@ def __init__(
max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType],
emit_team_events: bool,
model_client_streaming: bool = False,
) -> None:
super().__init__(
name,
Expand All @@ -79,6 +89,7 @@ def __init__(
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
self._model_client_streaming = model_client_streaming

async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
Expand Down Expand Up @@ -194,7 +205,37 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st
num_attempts = 0
while num_attempts < max_attempts:
num_attempts += 1
response = await self._model_client.create(messages=select_speaker_messages)
if self._model_client_streaming:
message: CreateResult | str = ""
async for _message in self._model_client.create_stream(messages=select_speaker_messages):
message = _message
Comment thread
ekzhu marked this conversation as resolved.
Outdated
if self._emit_team_events:
if isinstance(message, str):
await self._output_message_queue.put(
ModelClientStreamingChunkEvent(content=cast(str, _message), source=self._name)
)
else:
if isinstance(message, CreateResult):
response = message
else:
raise ValueError("Model failed to select a speaker.")

if isinstance(response.content, str):
if self._emit_team_events:
await self._output_message_queue.put(
SelectorEvent(content=response.content, source=self._name)
)
else:
response.content = "" # fallback to empty string
if self._emit_team_events:
await self._output_message_queue.put(
SelectorEvent(
content="Model failed to select a valid content type(it must str)",
source=self._name,
)
)
else:
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
Expand Down Expand Up @@ -281,6 +322,7 @@ class SelectorGroupChatConfig(BaseModel):
# selector_func: ComponentModel | None
max_selector_attempts: int = 3
emit_team_events: bool = False
model_client_streaming: bool = False


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
Expand Down Expand Up @@ -311,6 +353,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
model_client_streaming (bool, optional): Whether to use streaming for the model.(Only use for specify case e.g. QwQ) Defaults to False.
Comment thread
ekzhu marked this conversation as resolved.
Outdated

Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
Expand Down Expand Up @@ -453,6 +496,7 @@ def __init__(
candidate_func: Optional[CandidateFuncType] = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
model_client_streaming: bool = False,
):
super().__init__(
participants,
Expand All @@ -473,6 +517,7 @@ def __init__(
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._model_client_streaming = model_client_streaming

def _create_group_chat_manager_factory(
self,
Expand Down Expand Up @@ -505,6 +550,7 @@ def _create_group_chat_manager_factory(
self._max_selector_attempts,
self._candidate_func,
self._emit_team_events,
self._model_client_streaming,
)

def _to_config(self) -> SelectorGroupChatConfig:
Expand All @@ -518,6 +564,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
max_selector_attempts=self._max_selector_attempts,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
emit_team_events=self._emit_team_events,
model_client_streaming=self._model_client_streaming,
)

@classmethod
Expand All @@ -536,4 +583,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
# if config.selector_func
# else None,
emit_team_events=config.emit_team_events,
model_client_streaming=config.model_client_streaming,
)
59 changes: 59 additions & 0 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BaseChatMessage,
HandoffMessage,
MultiModalMessage,
SelectorEvent,
SelectSpeakerEvent,
StopMessage,
StructuredMessage,
Expand Down Expand Up @@ -1698,3 +1699,61 @@ async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None)
)

assert manager1._message_thread == manager2._message_thread # pyright: ignore


@pytest.mark.asyncio
async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient(
["agent3", "agent2", "agent1", "agent2", "agent1"],
)
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
agent2 = _EchoAgent("agent2", description="echo agent 2")
agent3 = _EchoAgent("agent3", description="echo agent 3")
termination = TextMentionTermination("TERMINATE")
team = SelectorGroupChat(
participants=[agent1, agent2, agent3],
model_client=model_client,
termination_condition=termination,
runtime=runtime,
emit_team_events=True,
model_client_streaming=True,
)
result = await team.run(
task="Write a program that prints 'Hello, world!'",
)

assert len(result.messages) == 16
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], SelectorEvent)
assert isinstance(result.messages[2], SelectSpeakerEvent)
assert isinstance(result.messages[3], TextMessage)
assert isinstance(result.messages[4], SelectorEvent)
assert isinstance(result.messages[5], SelectSpeakerEvent)
assert isinstance(result.messages[6], TextMessage)
assert isinstance(result.messages[7], SelectorEvent)
assert isinstance(result.messages[8], SelectSpeakerEvent)
assert isinstance(result.messages[9], TextMessage)
assert isinstance(result.messages[10], SelectorEvent)
assert isinstance(result.messages[11], SelectSpeakerEvent)
assert isinstance(result.messages[12], TextMessage)
assert isinstance(result.messages[13], SelectorEvent)
assert isinstance(result.messages[14], SelectSpeakerEvent)
assert isinstance(result.messages[15], StopMessage)

assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
assert result.messages[1].content == "agent3"
assert result.messages[2].content == ["agent3"]
assert result.messages[3].source == "agent3"
assert result.messages[4].content == "agent2"
assert result.messages[5].content == ["agent2"]
assert result.messages[6].source == "agent2"
assert result.messages[7].content == "agent1"
assert result.messages[8].content == ["agent1"]
assert result.messages[9].source == "agent1"
assert result.messages[10].content == "agent2"
assert result.messages[11].content == ["agent2"]
assert result.messages[12].source == "agent2"
assert result.messages[13].content == "agent1"
assert result.messages[14].content == ["agent1"]
assert result.messages[15].source == "agent1"
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"