Skip to content

Commit e9e4623

Browse files
authored
Merge branch 'main' into vdibia/organize_docs
2 parents e8afe97 + 8d4afe4 commit e9e4623

File tree

3 files changed

+117
-21
lines changed

3 files changed

+117
-21
lines changed

autogen/agentchat/conversable_agent.py

+9
Original file line numberDiff line numberDiff line change
@@ -1017,3 +1017,12 @@ def register_function(self, function_map: Dict[str, Callable]):
10171017
function_map: a dictionary mapping function names to functions.
10181018
"""
10191019
self._function_map.update(function_map)
1020+
1021+
def can_execute_function(self, name: str) -> bool:
1022+
"""Whether the agent can execute the function."""
1023+
return name in self._function_map
1024+
1025+
@property
1026+
def function_map(self) -> Dict[str, Callable]:
1027+
"""Return the function map."""
1028+
return self._function_map

autogen/agentchat/groupchat.py

+54-19
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,23 @@
1010

1111
@dataclass
1212
class GroupChat:
13-
"""A group chat class that contains a list of agents and the maximum number of rounds."""
13+
"""A group chat class that contains the following data fields:
14+
- agents: a list of participating agents.
15+
- messages: a list of messages in the group chat.
16+
- max_round: the maximum number of rounds.
17+
- admin_name: the name of the admin agent if there is one. Default is "Admin".
18+
KeyBoardInterrupt will make the admin agent take over.
19+
- func_call_filter: whether to enforce function call filter. Default is True.
20+
When set to True and when a message is a function call suggestion,
21+
the next speaker will be chosen from an agent which contains the corresponding function name
22+
in its `function_map`.
23+
"""
1424

1525
agents: List[Agent]
1626
messages: List[Dict]
1727
max_round: int = 10
18-
admin_name: str = "Admin" # the name of the admin agent
28+
admin_name: str = "Admin"
29+
func_call_filter: bool = True
1930

2031
@property
2132
def agent_names(self) -> List[str]:
@@ -30,45 +41,69 @@ def agent_by_name(self, name: str) -> Agent:
3041
"""Find the next speaker based on the message."""
3142
return self.agents[self.agent_names.index(name)]
3243

33-
def next_agent(self, agent: Agent) -> Agent:
44+
def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
3445
"""Return the next agent in the list."""
35-
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]
36-
37-
def select_speaker_msg(self):
46+
if agents == self.agents:
47+
return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
48+
else:
49+
offset = self.agent_names.index(agent.name) + 1
50+
for i in range(len(self.agents)):
51+
if self.agents[(offset + i) % len(self.agents)] in agents:
52+
return self.agents[(offset + i) % len(self.agents)]
53+
54+
def select_speaker_msg(self, agents: List[Agent]):
3855
"""Return the message for selecting the next speaker."""
3956
return f"""You are in a role play game. The following roles are available:
4057
{self._participant_roles()}.
4158
4259
Read the following conversation.
43-
Then select the next role from {self.agent_names} to play. Only return the role."""
60+
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
4461

4562
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
4663
"""Select the next speaker."""
47-
selector.update_system_message(self.select_speaker_msg())
48-
49-
# Warn if GroupChat is underpopulated, without established changing behavior
50-
n_agents = len(self.agent_names)
51-
if n_agents < 3:
52-
logger.warning(
53-
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
54-
)
55-
64+
if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
65+
# find agents with the right function_map which contains the function name
66+
agents = [
67+
agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
68+
]
69+
if len(agents) == 1:
70+
# only one agent can execute the function
71+
return agents[0]
72+
elif not agents:
73+
# find all the agents with function_map
74+
agents = [agent for agent in self.agents if agent.function_map]
75+
if len(agents) == 1:
76+
return agents[0]
77+
elif not agents:
78+
raise ValueError(
79+
f"No agent can execute the function {self.messages[-1]['name']}. "
80+
"Please check the function_map of the agents."
81+
)
82+
else:
83+
agents = self.agents
84+
# Warn if GroupChat is underpopulated
85+
n_agents = len(agents)
86+
if n_agents < 3:
87+
logger.warning(
88+
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
89+
)
90+
selector.update_system_message(self.select_speaker_msg(agents))
5691
final, name = selector.generate_oai_reply(
5792
self.messages
5893
+ [
5994
{
6095
"role": "system",
61-
"content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.",
96+
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
6297
}
6398
]
6499
)
65100
if not final:
66101
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
67-
return self.next_agent(last_speaker)
102+
return self.next_agent(last_speaker, agents)
68103
try:
69104
return self.agent_by_name(name)
70105
except ValueError:
71-
return self.next_agent(last_speaker)
106+
return self.next_agent(last_speaker, agents)
72107

73108
def _participant_roles(self):
74109
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])

test/agentchat/test_groupchat.py

+54-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,54 @@
1+
import pytest
12
import autogen
23

34

5+
def test_func_call_groupchat():
6+
agent1 = autogen.ConversableAgent(
7+
"alice",
8+
human_input_mode="NEVER",
9+
llm_config=False,
10+
default_auto_reply="This is alice sepaking.",
11+
)
12+
agent2 = autogen.ConversableAgent(
13+
"bob",
14+
human_input_mode="NEVER",
15+
llm_config=False,
16+
default_auto_reply="This is bob speaking.",
17+
function_map={"test_func": lambda x: x},
18+
)
19+
groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=3)
20+
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
21+
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})
22+
23+
assert len(groupchat.messages) == 3
24+
assert (
25+
groupchat.messages[-2]["role"] == "function"
26+
and groupchat.messages[-2]["name"] == "test_func"
27+
and groupchat.messages[-2]["content"] == "1"
28+
)
29+
assert groupchat.messages[-1]["name"] == "alice"
30+
31+
agent3 = autogen.ConversableAgent(
32+
"carol",
33+
human_input_mode="NEVER",
34+
llm_config=False,
35+
default_auto_reply="This is carol speaking.",
36+
function_map={"test_func": lambda x: x + 1},
37+
)
38+
groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=3)
39+
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
40+
agent3.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})
41+
42+
assert (
43+
groupchat.messages[-2]["role"] == "function"
44+
and groupchat.messages[-2]["name"] == "test_func"
45+
and groupchat.messages[-2]["content"] == "1"
46+
)
47+
assert groupchat.messages[-1]["name"] == "carol"
48+
49+
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
50+
51+
452
def test_chat_manager():
553
agent1 = autogen.ConversableAgent(
654
"alice",
@@ -30,6 +78,9 @@ def test_chat_manager():
3078
agent2.initiate_chat(group_chat_manager, message="hello")
3179
assert len(groupchat.messages) == 2
3280

81+
with pytest.raises(ValueError):
82+
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
83+
3384

3485
def test_plugin():
3586
# Give another Agent class ability to manage group chat
@@ -62,6 +113,7 @@ def test_plugin():
62113

63114

64115
if __name__ == "__main__":
116+
test_func_call_groupchat()
65117
# test_broadcast()
66-
# test_chat_manager()
67-
test_plugin()
118+
test_chat_manager()
119+
# test_plugin()

0 commit comments

Comments
 (0)