Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update speaker selector in GroupChat and update some notebooks #688

Merged
merged 27 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b6f528a
Add speaker selection methods
thinkall Nov 15, 2023
c558e68
Update groupchat RAG
thinkall Nov 15, 2023
8953173
Update seed to cache_seed
thinkall Nov 15, 2023
8cae4f4
Update RetrieveChat notebook
thinkall Nov 15, 2023
37a6f74
Update parameter name
thinkall Nov 15, 2023
9008d91
Add test
thinkall Nov 15, 2023
a660718
Add more tests
thinkall Nov 15, 2023
de073c8
Add mock to test
thinkall Nov 15, 2023
91d50ca
Add mock to test
thinkall Nov 15, 2023
5f24a0e
Merge branch 'main' into update_selector
thinkall Nov 16, 2023
da7c932
Fix typo speaking
thinkall Nov 16, 2023
4872cff
Add gracefully exit manual input
thinkall Nov 16, 2023
07e4fdb
Update round_robin docstring
thinkall Nov 16, 2023
28b1711
Add method checking
thinkall Nov 16, 2023
10255de
Merge main
thinkall Nov 17, 2023
7966c27
Remove participant roles
thinkall Nov 17, 2023
16f1681
Fix versions in notebooks
thinkall Nov 17, 2023
3dbbd33
Minimize installation overhead
thinkall Nov 17, 2023
9955a60
Fix missing lower()
thinkall Nov 17, 2023
e7c3cc3
Add comments for try_count 3
thinkall Nov 17, 2023
38cb8ae
Update warning for n_agents < 3
thinkall Nov 17, 2023
8c9cb59
Update warning for n_agents < 3
thinkall Nov 17, 2023
7449ca5
Add test_n_agents_less_than_3
thinkall Nov 17, 2023
8817302
Add a function for manual select
thinkall Nov 17, 2023
68164e8
Merge main
thinkall Nov 17, 2023
f28eca5
Update version in notebooks
thinkall Nov 17, 2023
0e230b6
Fixed bugs that allow speakers to go twice in a row even when allow_r…
afourney Nov 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
pip install -e. pytest
pip install -e. pytest mock
pip uninstall -y openai
- name: Install unstructured if not windows
if: matrix.os != 'windows-2019'
Expand Down
101 changes: 86 additions & 15 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import re
Expand All @@ -21,13 +22,24 @@ class GroupChat:
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True.
"""

agents: List[Agent]
messages: List[Dict]
max_round: int = 10
admin_name: str = "Admin"
func_call_filter: bool = True
speaker_selection_method: str = "auto"
allow_repeat_speaker: bool = True

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]

@property
def agent_names(self) -> List[str]:
Expand Down Expand Up @@ -55,13 +67,61 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
{self._participant_roles(agents)}.

Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

def manual_select_speaker(self, agents: List[Agent]) -> Agent:
"""Manually select the next speaker."""

print("Please select the next speaker from the following list:")
_n_agents = len(agents)
for i in range(_n_agents):
print(f"{i+1}: {agents[i].name}")
try_count = 0
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
while try_count <= 3:
try_count += 1
if try_count >= 3:
print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
break
try:
i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
if i == "" or i == "q":
break
i = int(i)
if i > 0 and i <= _n_agents:
return agents[i - 1]
else:
raise ValueError
except ValueError:
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
thinkall marked this conversation as resolved.
Show resolved Hide resolved
)

agents = self.agents
n_agents = len(agents)
# Warn if GroupChat is underpopulated
if n_agents < 2:
raise ValueError(
f"GroupChat is underpopulated with {n_agents} agents. "
"Please add more agents to the GroupChat or use direct communication instead."
)
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and self.allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
"Or, use direct communication instead."
)

if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
# find agents with the right function_map which contains the function name
agents = [
Expand All @@ -80,14 +140,20 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents."
)
else:
agents = self.agents
# Warn if GroupChat is underpopulated
n_agents = len(agents)
if n_agents < 3:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
)

# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]

if self.speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(agents)
if selected_agent:
return selected_agent
elif self.speaker_selection_method.lower() == "round_robin":
return self.next_agent(last_speaker, agents)
elif self.speaker_selection_method.lower() == "random":
return random.choice(agents)

# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply(
self.messages
Expand All @@ -99,26 +165,31 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
]
)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)

# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
)

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. Speaker selection will default to the next speaker in the list. This is because the speaker selection OAI call returned:\n{name}"
)
return self.next_agent(last_speaker, agents)

def _participant_roles(self):
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
agents = self.agents

roles = []
for agent in self.agents:
for agent in agents:
if agent.system_message.strip() == "":
logger.warning(
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."
Expand Down
Loading
Loading