Skip to content

Commit

Permalink
Partial fix for 960 (#963)
Browse files Browse the repository at this point in the history
* Partial fix for 960

* Fixed a missing = None

* Added test coverage.
  • Loading branch information
afourney authored Dec 24, 2023
1 parent 26b7aff commit b1adac5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
16 changes: 12 additions & 4 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def agent_by_name(self, name: str) -> Agent:
"""Returns the agent with a given name."""
return self.agents[self.agent_names.index(name)]

def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent:
"""Return the next agent in the list."""
if agents is None:
agents = self.agents

# What index is the agent? (-1 if not present)
idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1
Expand All @@ -79,20 +81,26 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]

def select_speaker_msg(self, agents: List[Agent]) -> str:
def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
if agents is None:
agents = self.agents
return f"""You are in a role play game. The following roles are available:
{self._participant_roles(agents)}.
Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

def select_speaker_prompt(self, agents: List[Agent]) -> str:
def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
if agents is None:
agents = self.agents
return f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."

def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
"""Manually select the next speaker."""
if agents is None:
agents = self.agents

print("Please select the next speaker from the following list:")
_n_agents = len(agents)
Expand Down
46 changes: 46 additions & 0 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ def test_next_agent():
assert groupchat.next_agent(agent2, [agent1, agent2, agent3]) == agent3
assert groupchat.next_agent(agent3, [agent1, agent2, agent3]) == agent1

assert groupchat.next_agent(agent1) == agent2
assert groupchat.next_agent(agent2) == agent3
assert groupchat.next_agent(agent3) == agent1

assert groupchat.next_agent(agent1, [agent1, agent3]) == agent3
assert groupchat.next_agent(agent3, [agent1, agent3]) == agent1

Expand All @@ -429,6 +433,48 @@ def test_next_agent():
assert groupchat.next_agent(agent4, [agent1, agent2, agent3]) == agent1


def test_selection_helpers():
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
description="Alice is an AI agent.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
description="Bob is an AI agent.",
)
agent3 = autogen.ConversableAgent(
"sam",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is sam speaking.",
system_message="Sam is an AI agent.",
)

# Test empty is_termination_msg function
groupchat = autogen.GroupChat(
agents=[agent1, agent2, agent3], messages=[], speaker_selection_method="round_robin", max_round=10
)

select_speaker_msg = groupchat.select_speaker_msg()
select_speaker_prompt = groupchat.select_speaker_prompt()

assert "Alice is an AI agent." in select_speaker_msg
assert "Bob is an AI agent." in select_speaker_msg
assert "Sam is an AI agent." in select_speaker_msg
assert str(["Alice", "Bob", "Sam"]).lower() in select_speaker_prompt.lower()

with mock.patch.object(builtins, "input", lambda _: "1"):
groupchat.manual_select_speaker()


if __name__ == "__main__":
# test_func_call_groupchat()
# test_broadcast()
Expand Down

0 comments on commit b1adac5

Please sign in to comment.