Skip to content

Commit 68d51d1

Browse files
thinkallafourney
andauthored
Update speaker selector in GroupChat and update some notebooks (microsoft#688)
* Add speaker selection methods * Update groupchat RAG * Update seed to cache_seed * Update RetrieveChat notebook * Update parameter name * Add test * Add more tests * Add mock to test * Add mock to test * Fix typo speaking * Add gracefully exit manual input * Update round_robin docstring * Add method checking * Remove participant roles * Fix versions in notebooks * Minimize installation overhead * Fix missing lower() * Add comments for try_count 3 * Update warning for n_agents < 3 * Update warning for n_agents < 3 * Add test_n_agents_less_than_3 * Add a function for manual select * Update version in notebooks * Fixed bugs that allow speakers to go twice in a row even when allow_repeat_speaker = False --------- Co-authored-by: Adam Fourney <[email protected]>
1 parent 399f505 commit 68d51d1

File tree

8 files changed

+1842
-3530
lines changed

8 files changed

+1842
-3530
lines changed

.github/workflows/build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
python -m pip install --upgrade pip wheel
4141
pip install -e .
4242
python -c "import autogen"
43-
pip install -e. pytest
43+
pip install -e. pytest mock
4444
pip uninstall -y openai
4545
- name: Install unstructured if not windows
4646
if: matrix.os != 'windows-2019'

autogen/agentchat/groupchat.py

+86-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import sys
3+
import random
34
from dataclasses import dataclass
45
from typing import Dict, List, Optional, Union
56
import re
@@ -21,13 +22,24 @@ class GroupChat:
2122
When set to True and when a message is a function call suggestion,
2223
the next speaker will be chosen from an agent which contains the corresponding function name
2324
in its `function_map`.
25+
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
26+
Could be any of the following (case insensitive), will raise ValueError if not recognized:
27+
- "auto": the next speaker is selected automatically by LLM.
28+
- "manual": the next speaker is selected manually by user input.
29+
- "random": the next speaker is selected randomly.
30+
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
31+
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True.
2432
"""
2533

2634
agents: List[Agent]
2735
messages: List[Dict]
2836
max_round: int = 10
2937
admin_name: str = "Admin"
3038
func_call_filter: bool = True
39+
speaker_selection_method: str = "auto"
40+
allow_repeat_speaker: bool = True
41+
42+
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
3143

3244
@property
3345
def agent_names(self) -> List[str]:
@@ -55,13 +67,61 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
5567
def select_speaker_msg(self, agents: List[Agent]):
5668
"""Return the message for selecting the next speaker."""
5769
return f"""You are in a role play game. The following roles are available:
58-
{self._participant_roles()}.
70+
{self._participant_roles(agents)}.
5971
6072
Read the following conversation.
6173
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
6274

75+
def manual_select_speaker(self, agents: List[Agent]) -> Agent:
76+
"""Manually select the next speaker."""
77+
78+
print("Please select the next speaker from the following list:")
79+
_n_agents = len(agents)
80+
for i in range(_n_agents):
81+
print(f"{i+1}: {agents[i].name}")
82+
try_count = 0
83+
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
84+
while try_count <= 3:
85+
try_count += 1
86+
if try_count >= 3:
87+
print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
88+
break
89+
try:
90+
i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
91+
if i == "" or i == "q":
92+
break
93+
i = int(i)
94+
if i > 0 and i <= _n_agents:
95+
return agents[i - 1]
96+
else:
97+
raise ValueError
98+
except ValueError:
99+
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
100+
return None
101+
63102
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
64103
"""Select the next speaker."""
104+
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
105+
raise ValueError(
106+
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
107+
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
108+
)
109+
110+
agents = self.agents
111+
n_agents = len(agents)
112+
# Warn if GroupChat is underpopulated
113+
if n_agents < 2:
114+
raise ValueError(
115+
f"GroupChat is underpopulated with {n_agents} agents. "
116+
"Please add more agents to the GroupChat or use direct communication instead."
117+
)
118+
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and self.allow_repeat_speaker:
119+
logger.warning(
120+
f"GroupChat is underpopulated with {n_agents} agents. "
121+
"It is recommended to set speaker_selection_method to 'round_robin' or allow_repeat_speaker to False."
122+
"Or, use direct communication instead."
123+
)
124+
65125
if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
66126
# find agents with the right function_map which contains the function name
67127
agents = [
@@ -80,14 +140,20 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
80140
f"No agent can execute the function {self.messages[-1]['name']}. "
81141
"Please check the function_map of the agents."
82142
)
83-
else:
84-
agents = self.agents
85-
# Warn if GroupChat is underpopulated
86-
n_agents = len(agents)
87-
if n_agents < 3:
88-
logger.warning(
89-
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
90-
)
143+
144+
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
145+
agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
146+
147+
if self.speaker_selection_method.lower() == "manual":
148+
selected_agent = self.manual_select_speaker(agents)
149+
if selected_agent:
150+
return selected_agent
151+
elif self.speaker_selection_method.lower() == "round_robin":
152+
return self.next_agent(last_speaker, agents)
153+
elif self.speaker_selection_method.lower() == "random":
154+
return random.choice(agents)
155+
156+
# auto speaker selection
91157
selector.update_system_message(self.select_speaker_msg(agents))
92158
final, name = selector.generate_oai_reply(
93159
self.messages
@@ -99,26 +165,31 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
99165
]
100166
)
101167
if not final:
102-
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
168+
# the LLM client is None, thus no reply is generated. Use round robin instead.
103169
return self.next_agent(last_speaker, agents)
104170

105171
# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
106172
mentions = self._mentioned_agents(name, agents)
107173
if len(mentions) == 1:
108174
name = next(iter(mentions))
175+
else:
176+
logger.warning(
177+
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
178+
)
109179

110180
# Return the result
111181
try:
112182
return self.agent_by_name(name)
113183
except ValueError:
114-
logger.warning(
115-
f"GroupChat select_speaker failed to resolve the next speaker's name. Speaker selection will default to the next speaker in the list. This is because the speaker selection OAI call returned:\n{name}"
116-
)
117184
return self.next_agent(last_speaker, agents)
118185

119-
def _participant_roles(self):
186+
def _participant_roles(self, agents: List[Agent] = None) -> str:
187+
# Default to all agents registered
188+
if agents is None:
189+
agents = self.agents
190+
120191
roles = []
121-
for agent in self.agents:
192+
for agent in agents:
122193
if agent.system_message.strip() == "":
123194
logger.warning(
124195
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."

0 commit comments

Comments
 (0)