Skip to content

Commit

Permalink
Update speaker selector in GroupChat and update some notebooks (#688)
Browse files Browse the repository at this point in the history
* Add speaker selection methods

* Update groupchat RAG

* Update seed to cache_seed

* Update RetrieveChat notebook

* Update parameter name

* Add test

* Add more tests

* Add mock to test

* Add mock to test

* Fix typo speaking

* Add gracefully exit manual input

* Update round_robin docstring

* Add method checking

* Remove participant roles

* Fix versions in notebooks

* Minimize installation overhead

* Fix missing lower()

* Add comments for try_count 3

* Update warning for n_agents < 3

* Update warning for n_agents < 3

* Add test_n_agents_less_than_3

* Add a function for manual select

* Update version in notebooks

* Fixed bugs that allow speakers to go twice in a row even when allow_repeat_speaker = False

---------

Co-authored-by: Adam Fourney <[email protected]>
  • Loading branch information
thinkall and afourney authored Nov 17, 2023
1 parent 3ab8c97 commit 370ebf5
Show file tree
Hide file tree
Showing 8 changed files with 1,842 additions and 3,530 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
pip install -e. pytest
pip install -e. pytest mock
pip uninstall -y openai
- name: Install unstructured if not windows
if: matrix.os != 'windows-2019'
Expand Down
101 changes: 86 additions & 15 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import re
Expand All @@ -21,13 +22,24 @@ class GroupChat:
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True.
"""

agents: List[Agent]
messages: List[Dict]
max_round: int = 10
admin_name: str = "Admin"
func_call_filter: bool = True
speaker_selection_method: str = "auto"
allow_repeat_speaker: bool = True

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]

@property
def agent_names(self) -> List[str]:
Expand Down Expand Up @@ -55,13 +67,61 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
{self._participant_roles(agents)}.
Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

def manual_select_speaker(self, agents: List[Agent]) -> Agent:
"""Manually select the next speaker."""

print("Please select the next speaker from the following list:")
_n_agents = len(agents)
for i in range(_n_agents):
print(f"{i+1}: {agents[i].name}")
try_count = 0
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
while try_count <= 3:
try_count += 1
if try_count >= 3:
print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
break
try:
i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
if i == "" or i == "q":
break
i = int(i)
if i > 0 and i <= _n_agents:
return agents[i - 1]
else:
raise ValueError
except ValueError:
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)

agents = self.agents
n_agents = len(agents)
# Warn if GroupChat is underpopulated
if n_agents < 2:
raise ValueError(
f"GroupChat is underpopulated with {n_agents} agents. "
"Please add more agents to the GroupChat or use direct communication instead."
)
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and self.allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
"Or, use direct communication instead."
)

if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
# find agents with the right function_map which contains the function name
agents = [
Expand All @@ -80,14 +140,20 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents."
)
else:
agents = self.agents
# Warn if GroupChat is underpopulated
n_agents = len(agents)
if n_agents < 3:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
)

# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]

if self.speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(agents)
if selected_agent:
return selected_agent
elif self.speaker_selection_method.lower() == "round_robin":
return self.next_agent(last_speaker, agents)
elif self.speaker_selection_method.lower() == "random":
return random.choice(agents)

# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply(
self.messages
Expand All @@ -99,26 +165,31 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
]
)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)

# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
)

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. Speaker selection will default to the next speaker in the list. This is because the speaker selection OAI call returned:\n{name}"
)
return self.next_agent(last_speaker, agents)

def _participant_roles(self):
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
agents = self.agents

roles = []
for agent in self.agents:
for agent in agents:
if agent.system_message.strip() == "":
logger.warning(
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."
Expand Down
Loading

0 comments on commit 370ebf5

Please sign in to comment.