Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Get Nested Agents in a GroupChat #1636

Merged
merged 18 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autogen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .version import __version__
from .oai import *
from .agentchat import *
from .exception_utils import *
from .code_utils import DEFAULT_MODEL, FAST_MODEL


Expand Down
35 changes: 28 additions & 7 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


from ..code_utils import content_str
from ..exception_utils import AgentNameConflict
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..runtime_logging import logging_enabled, log_new_agent
Expand Down Expand Up @@ -174,9 +175,26 @@ def append(self, message: Dict, speaker: Agent):
message["content"] = content_str(message["content"])
self.messages.append(message)

def agent_by_name(self, name: str) -> Agent:
"""Returns the agent with a given name."""
return self.agents[self.agent_names.index(name)]
def agent_by_name(
WaelKarkoub marked this conversation as resolved.
Show resolved Hide resolved
self, name: str, recursive: bool = False, raise_on_name_conflict: bool = False
) -> Optional[Agent]:
"""Returns the agent with a given name. If recursive is True, it will search in nested teams."""
agents = self.nested_agents() if recursive else self.agents
filtered_agents = [agent for agent in agents if agent.name == name]

if raise_on_name_conflict and len(filtered_agents) > 1:
raise AgentNameConflict()

return filtered_agents[0] if filtered_agents else None

def nested_agents(self) -> List[Agent]:
"""Returns all agents in the group chat manager."""
agents = self.agents.copy()
for agent in agents:
if isinstance(agent, GroupChatManager):
# Recursive call for nested teams
agents.extend(agent.groupchat.nested_agents())
return agents

def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent:
"""Return the next agent in the list."""
Expand Down Expand Up @@ -390,10 +408,8 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
)

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker, agents)
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)

def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
Expand Down Expand Up @@ -480,6 +496,11 @@ def __init__(
ignore_async_in_sync_chat=True,
)

@property
def groupchat(self) -> GroupChat:
"""Returns the group chat managed by the group chat manager."""
return self._groupchat

def chat_messages_for_summary(self, agent: Agent) -> List[Dict]:
"""The list of messages in the group chat as a conversation to summarize.
The agent is ignored.
Expand Down
3 changes: 3 additions & 0 deletions autogen/exception_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class AgentNameConflict(Exception):
def __init__(self, msg="Found multiple agents with the same name.", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
132 changes: 132 additions & 0 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict, List, Optional, Type
from autogen import AgentNameConflict
import pytest
from unittest import mock
import builtins
Expand Down Expand Up @@ -672,6 +674,136 @@ def test_clear_agents_history():
]


def test_get_agent_by_name():
def agent(name: str) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name=name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
)

def team(members: List[autogen.Agent], name: str) -> autogen.Agent:
gc = autogen.GroupChat(agents=members, messages=[])

return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)

team_member1 = agent("team1_member1")
team_member2 = agent("team1_member2")
team_dup_member1 = agent("team1_member1")
team_dup_member2 = agent("team1_member2")

user = agent("user")
team1 = team([team_member1, team_member2], "team1")
team1_duplicate = team([team_dup_member1, team_dup_member2], "team1")

gc = autogen.GroupChat(agents=[user, team1, team1_duplicate], messages=[])

# Testing default arguments
assert gc.agent_by_name("user") == user
assert gc.agent_by_name("team1") == team1 or gc.agent_by_name("team1") == team1_duplicate

# Testing recursive search
assert gc.agent_by_name("user", recursive=True) == user
assert (
gc.agent_by_name("team1_member1", recursive=True) == team_member1
or gc.agent_by_name("team1_member1", recursive=True) == team_dup_member1
)

# Get agent that does not exist
assert gc.agent_by_name("team2") is None
assert gc.agent_by_name("team2", recursive=True) is None
assert gc.agent_by_name("team2", raise_on_name_conflict=True) is None
assert gc.agent_by_name("team2", recursive=True, raise_on_name_conflict=True) is None

# Testing naming conflict
with pytest.raises(AgentNameConflict):
gc.agent_by_name("team1", raise_on_name_conflict=True)

# Testing name conflict with recursive search
with pytest.raises(AgentNameConflict):
gc.agent_by_name("team1_member1", recursive=True, raise_on_name_conflict=True)


def test_get_nested_agents_in_groupchat():
def agent(name: str) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name=name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
)

def team(name: str) -> autogen.ConversableAgent:
member1 = agent(f"member1_{name}")
member2 = agent(f"member2_{name}")

gc = autogen.GroupChat(agents=[member1, member2], messages=[])

return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)

user = agent("user")
team1 = team("team1")
team2 = team("team2")
WaelKarkoub marked this conversation as resolved.
Show resolved Hide resolved

gc = autogen.GroupChat(agents=[user, team1, team2], messages=[])

agents = gc.nested_agents()
assert len(agents) == 7


def test_nested_teams_chat():
"""Tests chat capabilities of nested teams"""
team1_msg = {"content": "Hello from team 1"}
team2_msg = {"content": "Hello from team 2"}

def agent(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name=name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply=auto_reply,
)

def team(name: str, auto_reply: Optional[Dict[str, Any]] = None) -> autogen.ConversableAgent:
member1 = agent(f"member1_{name}", auto_reply=auto_reply)
member2 = agent(f"member2_{name}", auto_reply=auto_reply)

gc = autogen.GroupChat(agents=[member1, member2], messages=[])

return autogen.GroupChatManager(groupchat=gc, name=name, llm_config=False)

def chat(gc_manager: autogen.GroupChatManager):
team1_member1 = gc_manager.groupchat.agent_by_name("member1_team1", recursive=True)
team2_member2 = gc_manager.groupchat.agent_by_name("member2_team2", recursive=True)

assert team1_member1 is not None
assert team2_member2 is not None

team1_member1.send(team1_msg, team2_member2, request_reply=True)

user = agent("user")
team1 = team("team1", auto_reply=team1_msg)
team2 = team("team2", auto_reply=team2_msg)

gc = autogen.GroupChat(agents=[user, team1, team2], messages=[])
gc_manager = autogen.GroupChatManager(groupchat=gc, llm_config=False)

chat(gc_manager)

team1_member1 = gc.agent_by_name("member1_team1", recursive=True)
team2_member2 = gc.agent_by_name("member2_team2", recursive=True)

assert team1_member1 and team2_member2

msg = team1_member1.chat_messages[team2_member2][0]
reply = team1_member1.chat_messages[team2_member2][1]

assert msg["content"] == team1_msg["content"]
assert reply["content"] == team2_msg["content"]


if __name__ == "__main__":
# test_func_call_groupchat()
# test_broadcast()
Expand Down
Loading