Skip to content

Commit

Permalink
From #753
Browse files Browse the repository at this point in the history
  • Loading branch information
joshkyh committed Dec 4, 2023
1 parent d010679 commit 747d0f0
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-openai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,4 @@ jobs:
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
flags: unittests
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Age
return next_speaker
else:
# Cannot return next_speaker with no eligible speakers
raise ValueError("No eligible speakers found based on the graph constraints.")
raise ValueError("No eligible speakers found based on the graph constraints.")
73 changes: 44 additions & 29 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class GroupChat:
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- "graph": the next speaker is selected based on a graph. The select_speaker method is overridden by GraphGroupChat.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True.
"""

Expand All @@ -41,7 +42,7 @@ class GroupChat:
speaker_selection_method: str = "auto"
allow_repeat_speaker: bool = True

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin", "graph"]

@property
def agent_names(self) -> List[str]:
Expand Down Expand Up @@ -109,6 +110,40 @@ def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None

def auto_select_speaker(
    self, agents: List[Agent], last_speaker: Agent, selector: ConversableAgent
) -> (Agent, List[Agent], Agent, ConversableAgent):
    """Pick the next speaker by asking the selector LLM.

    Encapsulated as a method so subclasses (e.g. GraphGroupChat) can reuse it.
    Returns a 4-tuple of (selected agent, agents, last_speaker, selector) so
    the caller's inputs are handed back unchanged alongside the selection.
    Falls back to round-robin whenever the LLM cannot produce a usable name.
    """
    selector.update_system_message(self.select_speaker_msg(agents))
    prompt = {
        "role": "system",
        "content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
    }
    final, name = selector.generate_oai_reply(self.messages + [prompt])

    if not final:
        # No LLM client available, so no reply was generated: round robin.
        return self.next_agent(last_speaker, agents), agents, last_speaker, selector

    # Resolve the reply to an agent name only when exactly one is mentioned;
    # otherwise keep the raw OAI response and warn.
    mentions = self._mentioned_agents(name, agents)
    if len(mentions) == 1:
        name = next(iter(mentions))
    else:
        logger.warning(
            f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
        )

    # Unknown names degrade gracefully to round robin.
    try:
        chosen = self.agent_by_name(name)
    except ValueError:
        chosen = self.next_agent(last_speaker, agents)
    return chosen, agents, last_speaker, selector

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
Expand Down Expand Up @@ -162,36 +197,16 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
return self.next_agent(last_speaker, agents)
elif self.speaker_selection_method.lower() == "random":
return random.choice(agents)

# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply(
self.messages
+ [
{
"role": "system",
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
}
]
)
if not final:
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)

# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
elif self.speaker_selection_method.lower() == "graph":
# This should not trigger because GraphGroupChat select_speaker overrides GroupChat select_speaker
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"GraphGroupChat select_speaker overrides GroupChat select_speaker. "
)

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker, agents)
# Since it is not manual nor round_robin nor random, it must be auto
auto_selected_speaker, agents, last_speaker, selector = self.auto_select_speaker(agents, last_speaker, selector)
return auto_selected_speaker

def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython"],
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
"graphs": ["networkx~=3.2.1", "matplotlib~=3.8.1"],
"graphs": ["networkx", "matplotlib"],
},
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
252 changes: 252 additions & 0 deletions test/agentchat/contrib/test_graphgroupchat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
# Optional dependencies: the GraphGroupChat tests need networkx, matplotlib,
# and the contrib GraphGroupChat class.
try:
    import networkx as nx
    import matplotlib.pyplot as plt
    from autogen.agentchat.contrib.graphgroupchat import GraphGroupChat

    skip_test = False
except (ModuleNotFoundError, ImportError):
    # BUG FIX: this branch previously set `skip_test = False`, so the
    # @pytest.mark.skipif guards never fired and every test crashed with a
    # NameError when the optional dependencies were not installed.
    skip_test = True


import autogen
from io import StringIO
import logging

import sys
import os
import pytest
from unittest.mock import MagicMock
import unittest
from autogen.agentchat.groupchat import GroupChat, Agent, ConversableAgent
from autogen.agentchat.assistant_agent import AssistantAgent


sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402


# Commented-out alternative: load an OpenAI config filtered by api_type.
#config_list = autogen.config_list_from_json(
# OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"api_type": ["openai"]}
#)

# Model configuration shared by every test class below; keeps only the
# "dev-oai-gpt4" entries from the OAI config file.
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, filter_dict={"model": ["dev-oai-gpt4"]})

# Sanity check, left disabled so the module still imports with an empty config.
# assert len(config_list) > 0


@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows or dependency is not installed",
)
class TestGraphGroupChatGraphValidity(unittest.TestCase):
    """Constructor-time validation of the graph handed to GraphGroupChat."""

    def test_graph_with_no_nodes(self):
        # An empty graph must be rejected outright.
        agents = [Agent(name="Agent 1"), Agent(name="Agent 2")]
        messages = []
        graph = nx.DiGraph()

        with pytest.raises(ValueError) as excinfo:
            GraphGroupChat(agents, messages, graph)
        assert "The graph has no nodes." in str(excinfo.value)

    def test_graph_with_no_edges(self):
        # Nodes without edges give the agents nowhere to go.
        agents = [Agent(name="Agent 1"), Agent(name="Agent 2")]
        messages = []
        graph = nx.DiGraph()
        graph.add_node("Agent 1")

        with pytest.raises(ValueError) as excinfo:
            GraphGroupChat(agents, messages, graph)
        assert "The graph has no edges." in str(excinfo.value)

    def test_graph_with_no_first_round_speaker(self):
        # At least one node must be flagged as the opening speaker.
        agents = [Agent(name="Agent 1"), Agent(name="Agent 2")]
        messages = []
        graph = nx.DiGraph()
        graph.add_node("Agent 1")
        graph.add_edge("Agent 1", "Agent 2")

        with pytest.raises(ValueError) as excinfo:
            GraphGroupChat(agents, messages, graph)
        assert "The graph has no nodes with 'first_round_speaker' set to True." in str(excinfo.value)

    def test_graph_with_self_loops(self):
        # Self-loops conflict with allow_repeat_speaker=False.
        agents = [Agent(name="Agent 1"), Agent(name="Agent 2")]
        messages = []
        graph = nx.DiGraph()
        graph.add_node("Agent 1", first_round_speaker=True)
        graph.add_node("Agent 2", first_round_speaker=False)
        graph.add_edge("Agent 1", "Agent 1")

        with pytest.raises(ValueError) as excinfo:
            GraphGroupChat(agents, messages, graph, allow_repeat_speaker=False)
        assert "The graph has self-loops, but self.allow_repeat_speaker is False." in str(excinfo.value)

    def test_warning_isolated_agents(self):
        """A node with no edges should trigger an 'isolated agents' warning."""
        graph = nx.DiGraph()
        graph.add_node("Agent1", first_round_speaker=True)
        graph.add_node("Agent2")
        graph.add_edge("Agent1", "Agent3")  # Agent2 is isolated

        agents = [Agent(name="Agent1"), Agent(name="Agent2"), Agent(name="Agent3")]

        # Capture warnings emitted through the root logger.
        log_capture_string = StringIO()
        ch = logging.StreamHandler(log_capture_string)
        ch.setLevel(logging.WARNING)
        root_logger = logging.getLogger()
        root_logger.addHandler(ch)
        try:
            GraphGroupChat(agents=agents, messages=[], graph=graph)
            log_contents = log_capture_string.getvalue()
        finally:
            # FIX: the handler was previously never removed, leaking it into
            # every later test and duplicating captured log output.
            root_logger.removeHandler(ch)

        self.assertIn("isolated agents", log_contents)

    def test_warning_agents_not_in_graph(self):
        """Agents missing from the graph should trigger a warning."""
        graph = nx.DiGraph()
        graph.add_node("Agent1", first_round_speaker=True)
        graph.add_node("Agent2")
        # Note: Agent3 is deliberately absent from the graph.
        graph.add_edge("Agent1", "Agent2")

        agents = [Agent(name="Agent1"), Agent(name="Agent2"), Agent(name="Agent3")]

        # Capture warnings emitted through the root logger.
        log_capture_string = StringIO()
        ch = logging.StreamHandler(log_capture_string)
        ch.setLevel(logging.WARNING)
        root_logger = logging.getLogger()
        root_logger.addHandler(ch)
        try:
            GraphGroupChat(agents=agents, messages=[], graph=graph)
            log_contents = log_capture_string.getvalue()
        finally:
            # FIX: remove the handler so it does not leak across tests.
            root_logger.removeHandler(ch)

        self.assertIn("agents not in self.agents", log_contents)


@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows or dependency is not installed",
)
class TestGraphGroupChatSelectSpeakerThreeAssistantAgents:
    @pytest.fixture(autouse=True)
    def setup(self):
        """Three assistant agents on a fully connected directed graph."""
        llm_config = {"config_list": config_list, "cache_seed": 100}

        self.agent1 = AssistantAgent(name="alice", llm_config=llm_config)
        self.agent2 = AssistantAgent(name="bob", llm_config=llm_config)
        self.agent3 = AssistantAgent(name="charlie", llm_config=llm_config)
        self.agents = [self.agent1, self.agent2, self.agent3]

        # Selector whose LLM reply is stubbed to always name "bob".
        self.selector = ConversableAgent(name="selector", llm_config=llm_config)
        self.selector.generate_oai_reply = MagicMock(return_value=(True, "bob"))

        # Fully connected digraph; alice is flagged to open the conversation.
        self.graph = nx.DiGraph()
        self.graph.add_node(self.agent1.name, first_round_speaker=True)
        self.graph.add_node(self.agent2.name)
        self.graph.add_node(self.agent3.name)
        for source in self.agents:
            for target in self.agents:
                if source is not target:
                    self.graph.add_edge(source.name, target.name)

    def test_select_first_round_speaker(self):
        # With no history, the designated first-round speaker is chosen.
        chat = GraphGroupChat(self.agents, [], self.graph)
        speaker = chat.select_speaker(last_speaker=None, selector=self.selector)
        assert speaker.name == "alice"

    def test_using_suggested_next_speaker(self):
        # An explicit "NEXT: <name>" hint in the last message wins.
        history = [{"content": "Some message. NEXT: charlie"}]
        chat = GraphGroupChat(self.agents, history, self.graph)
        speaker = chat.select_speaker(last_speaker=self.agent1, selector=self.selector)
        assert speaker.name == "charlie"

    def test_using_llm_to_pick_speaker(self):
        # Without a hint, the (mocked) LLM selector decides among successors.
        chat = GraphGroupChat(self.agents, [{"content": "Some message."}], self.graph)
        speaker = chat.select_speaker(last_speaker=self.agent1, selector=self.selector)
        assert speaker.name in ("bob", "charlie")

    def test_random_speaker_selection(self):
        # With repeats disallowed, selection falls back to random.choice,
        # patched here to return bob deterministically.
        chat = GraphGroupChat(self.agents, [], self.graph, allow_repeat_speaker=False)
        chat.previous_speaker = self.agent3
        with unittest.mock.patch("random.choice", return_value=self.agent2):
            speaker = chat.select_speaker(last_speaker=self.agent3, selector=self.selector)
        assert speaker.name == "bob"


@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows or dependency is not installed",
)
class TestGraphGroupChatSelectSpeakerOneAssistantAgentOneUserProxy:
    @pytest.fixture(autouse=True)
    def setup(self):
        """One assistant plus a terminating user proxy on a two-node graph."""
        self.llm_config = {"config_list": config_list, "cache_seed": 100}

        self.agent1 = AssistantAgent(name="alice", llm_config=self.llm_config)

        def is_termination_msg(content) -> bool:
            # End the chat as soon as TERMINATE appears in a message body.
            have_content = content.get("content", None) is not None
            if have_content and "TERMINATE" in content["content"]:
                return True
            return False

        # Terminates the conversation when TERMINATE is detected.
        self.user_proxy = autogen.UserProxyAgent(
            name="User_proxy",
            system_message="Terminator admin.",
            code_execution_config=False,
            is_termination_msg=is_termination_msg,
            human_input_mode="NEVER",
        )

        self.agents = [self.user_proxy, self.agent1]

        # Fully connected two-node graph; both agents may open the chat.
        self.graph = nx.DiGraph()
        for agent in self.agents:
            self.graph.add_node(agent.name, first_round_speaker=True)
        for agent1 in self.agents:
            for agent2 in self.agents:
                if agent1 != agent2:
                    self.graph.add_edge(agent1.name, agent2.name)

    def test_interaction(self):
        graph_group_chat = GraphGroupChat(
            agents=self.agents, messages=[], max_round=20, graph=self.graph  # Include all agents
        )

        # Create the manager
        manager = autogen.GroupChatManager(groupchat=graph_group_chat, llm_config=self.llm_config)

        # Initiates the chat with Alice
        self.agents[0].initiate_chat(
            manager,
            message="""Ask alice what is the largest single digit prime number.""",
        )

        # FIX: use .get(...) with safe defaults so messages missing a
        # "content" (or with content=None, e.g. function calls) or missing a
        # "name" key don't raise KeyError/TypeError inside the assertions.
        # Assert the answer 7 appears somewhere — don't just check the last message.
        assert any("7" in (message.get("content") or "") for message in graph_group_chat.messages)

        # Assert alice spoke at some point.
        assert any("alice" in message.get("name", "") for message in graph_group_chat.messages)
Loading

0 comments on commit 747d0f0

Please sign in to comment.