Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to pass in a customized speaker selection method #1791

Merged
merged 30 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e89ec5e
init PR
yiranwu0 Feb 26, 2024
ac7330c
Merge branch 'main' into dcheck
yiranwu0 Feb 26, 2024
17f3e49
Merge remote-tracking branch 'origin/main' into dcheck
yiranwu0 Feb 29, 2024
405c500
update
yiranwu0 Feb 29, 2024
9f81d5e
Merge branch 'main' into dcheck
yiranwu0 Feb 29, 2024
6786cc2
update code check
yiranwu0 Feb 29, 2024
75f80a9
Merge branch 'main' into dcheck
yiranwu0 Feb 29, 2024
81e9ff0
update
yiranwu0 Feb 29, 2024
ff0a8de
update
yiranwu0 Feb 29, 2024
8fed3f2
Merge remote-tracking branch 'origin/main' into dcheck
yiranwu0 Feb 29, 2024
ae6c167
update
yiranwu0 Feb 29, 2024
b259b84
update
yiranwu0 Feb 29, 2024
251cba6
Test the ability to have agents a,u,t,o,g,e,n speak in turn.
joshkyh Mar 1, 2024
c90d2a6
Merge remote-tracking branch 'origin/main' into dcheck
yiranwu0 Mar 2, 2024
0dbc08a
update
yiranwu0 Mar 2, 2024
22b24d8
update
yiranwu0 Mar 2, 2024
7f18e83
update
yiranwu0 Mar 2, 2024
8477c61
Evidence that groupchat not terminating because of the TERMINATE subs…
joshkyh Mar 2, 2024
4bfb0a0
Raising NoEligibleSpeakerException allows graceful exit before max turns
joshkyh Mar 2, 2024
bdd09db
Merge remote-tracking branch 'origin/main' into dcheck
yiranwu0 Mar 2, 2024
9ac0351
update
yiranwu0 Mar 2, 2024
908dea9
Merge branch 'main' into dcheck
yiranwu0 Mar 2, 2024
899a09e
Merge branch 'main' into dcheck
yiranwu0 Mar 2, 2024
c4ab4a2
To confirm with author that custom function is meant to override grap…
joshkyh Mar 3, 2024
a0dac5c
Merge branch 'main' into dcheck
yiranwu0 Mar 3, 2024
3cb43e9
Confirmed the expected test behaviour with author
joshkyh Mar 4, 2024
a098937
Update autogen/agentchat/groupchat.py
qingyun-wu Mar 5, 2024
4d327ee
Merge remote-tracking branch 'origin/main' into dcheck
yiranwu0 Mar 6, 2024
b4dd24d
update
yiranwu0 Mar 6, 2024
3d729b4
update
yiranwu0 Mar 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple
from typing import Dict, List, Optional, Union, Tuple, Callable


from ..code_utils import content_str
Expand Down Expand Up @@ -42,7 +42,16 @@ class GroupChat:
- "manual": the next speaker is selected manually by user input.
- "random": the next speaker is selected randomly.
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.

- a customized speaker selection function (Callable): the function will be called to select the next speaker.
The function should take the last speaker and the group chat as input and return one of the following:
1. an `Agent` class, it must be one of the agents in the group chat.
2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
qingyun-wu marked this conversation as resolved.
Show resolved Hide resolved
qingyun-wu marked this conversation as resolved.
Show resolved Hide resolved
3. None, which would terminate the conversation gracefully.
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved
```python
def custom_speaker_selection_func(
last_speaker: Agent, groupchat: GroupChat
) -> Union[Agent, str, None]:
```
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
Default is True, in which case all speakers are allowed to speak consecutively.
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
Expand All @@ -67,7 +76,7 @@ class GroupChat:
max_round: Optional[int] = 10
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Optional[str] = "auto"
speaker_selection_method: Optional[Union[str, Callable]] = "auto"
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Optional[str] = None
Expand Down Expand Up @@ -277,11 +286,36 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
return random.choice(agents)

def _prepare_and_select_agents(
self, last_speaker: Agent
self,
last_speaker: Agent,
) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]:
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
# If self.speaker_selection_method is a callable, call it to get the next speaker.
# If self.speaker_selection_method is a string, return it.
speaker_selection_method = self.speaker_selection_method
if isinstance(self.speaker_selection_method, Callable):
selected_agent = self.speaker_selection_method(last_speaker, self)
if selected_agent is None:
raise NoEligibleSpeakerException(
"Custom speaker selection function returned None. Terminating conversation."
)
elif isinstance(selected_agent, Agent):
if selected_agent in self.agents:
return selected_agent, self.agents, None
else:
raise ValueError(
f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat."
)
elif isinstance(selected_agent, str):
# If returned a string, assume it is a speaker selection method
speaker_selection_method = selected_agent
else:
raise ValueError(
f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str."
)

if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. "
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
)

Expand All @@ -300,7 +334,7 @@ def _prepare_and_select_agents(
f"GroupChat is underpopulated with {n_agents} agents. "
"Please add more agents to the GroupChat or use direct communication instead."
)
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. "
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
Expand Down Expand Up @@ -366,11 +400,11 @@ def _prepare_and_select_agents(

# Use the selected speaker selection method
select_speaker_messages = None
if self.speaker_selection_method.lower() == "manual":
if speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(graph_eligible_agents)
elif self.speaker_selection_method.lower() == "round_robin":
elif speaker_selection_method.lower() == "round_robin":
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
elif self.speaker_selection_method.lower() == "random":
elif speaker_selection_method.lower() == "random":
selected_agent = self.random_select_speaker(graph_eligible_agents)
else:
selected_agent = None
Expand Down
1 change: 1 addition & 0 deletions notebook/agentchat_custom_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@
"source": [
"# load model here\n",
"\n",
"\n",
"config = config_list_custom[0]\n",
"device = config.get(\"device\", \"cpu\")\n",
"loaded_model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n",
Expand Down
Loading
Loading