Skip to content

Commit 6e23871

Browse files
kittyandrewsonichiqingyun-wu
authored
Make groupchat & generation async, actually (#543)
* make groupchat & generation async actually * factored out func call pre-select; updated indecies * fixed code format issue * mark prepare agents subset as internal * func renaming * func inputs * return agents * Update test/agentchat/test_async.py Co-authored-by: Chi Wang <[email protected]> * Update notebook/agentchat_stream.ipynb Co-authored-by: Chi Wang <[email protected]> --------- Co-authored-by: Chi Wang <[email protected]> Co-authored-by: Qingyun Wu <[email protected]>
1 parent 379d7bd commit 6e23871

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

autogen/agentchat/conversable_agent.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import copy
3+
import functools
34
import json
45
import logging
56
from collections import defaultdict
@@ -133,9 +134,10 @@ def __init__(
133134
self._reply_func_list = []
134135
self.reply_at_receive = defaultdict(bool)
135136
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
137+
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
136138
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
137139
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
138-
self.register_reply([Agent, None], ConversableAgent.generate_async_function_call_reply)
140+
self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
139141
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
140142
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
141143

@@ -631,6 +633,17 @@ def generate_oai_reply(
631633
)
632634
return True, client.extract_text_or_function_call(response)[0]
633635

636+
async def a_generate_oai_reply(
637+
self,
638+
messages: Optional[List[Dict]] = None,
639+
sender: Optional[Agent] = None,
640+
config: Optional[Any] = None,
641+
) -> Tuple[bool, Union[str, Dict, None]]:
642+
"""Generate a reply using autogen.oai asynchronously."""
643+
return await asyncio.get_event_loop().run_in_executor(
644+
None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
645+
)
646+
634647
def generate_code_execution_reply(
635648
self,
636649
messages: Optional[List[Dict]] = None,
@@ -697,7 +710,7 @@ def generate_function_call_reply(
697710
return True, func_return
698711
return False, None
699712

700-
async def generate_async_function_call_reply(
713+
async def a_generate_function_call_reply(
701714
self,
702715
messages: Optional[List[Dict]] = None,
703716
sender: Optional[Agent] = None,

autogen/agentchat/groupchat.py

+50-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import sys
55
from dataclasses import dataclass
6-
from typing import Dict, List, Optional, Union
6+
from typing import Dict, List, Optional, Union, Tuple
77

88
from ..code_utils import content_str
99
from .agent import Agent
@@ -118,8 +118,7 @@ def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
118118
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
119119
return None
120120

121-
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
122-
"""Select the next speaker."""
121+
def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agent], List[Agent]]:
123122
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
124123
raise ValueError(
125124
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
@@ -148,30 +147,35 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
148147
]
149148
if len(agents) == 1:
150149
# only one agent can execute the function
151-
return agents[0]
150+
return agents[0], agents
152151
elif not agents:
153152
# find all the agents with function_map
154153
agents = [agent for agent in self.agents if agent.function_map]
155154
if len(agents) == 1:
156-
return agents[0]
155+
return agents[0], agents
157156
elif not agents:
158157
raise ValueError(
159158
f"No agent can execute the function {self.messages[-1]['name']}. "
160159
"Please check the function_map of the agents."
161160
)
162-
163161
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
164162
agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
165163

166164
if self.speaker_selection_method.lower() == "manual":
167165
selected_agent = self.manual_select_speaker(agents)
168-
if selected_agent:
169-
return selected_agent
170166
elif self.speaker_selection_method.lower() == "round_robin":
171-
return self.next_agent(last_speaker, agents)
167+
selected_agent = self.next_agent(last_speaker, agents)
172168
elif self.speaker_selection_method.lower() == "random":
173-
return random.choice(agents)
169+
selected_agent = random.choice(agents)
170+
else:
171+
selected_agent = None
172+
return selected_agent, agents
174173

174+
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
175+
"""Select the next speaker."""
176+
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
177+
if selected_agent:
178+
return selected_agent
175179
# auto speaker selection
176180
selector.update_system_message(self.select_speaker_msg(agents))
177181
context = self.messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]
@@ -196,6 +200,41 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
196200
except ValueError:
197201
return self.next_agent(last_speaker, agents)
198202

203+
async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
204+
"""Select the next speaker."""
205+
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
206+
if selected_agent:
207+
return selected_agent
208+
# auto speaker selection
209+
selector.update_system_message(self.select_speaker_msg(agents))
210+
final, name = await selector.a_generate_oai_reply(
211+
self.messages
212+
+ [
213+
{
214+
"role": "system",
215+
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
216+
}
217+
]
218+
)
219+
if not final:
220+
# the LLM client is None, thus no reply is generated. Use round robin instead.
221+
return self.next_agent(last_speaker, agents)
222+
223+
# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
224+
mentions = self._mentioned_agents(name, agents)
225+
if len(mentions) == 1:
226+
name = next(iter(mentions))
227+
else:
228+
logger.warning(
229+
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
230+
)
231+
232+
# Return the result
233+
try:
234+
return self.agent_by_name(name)
235+
except ValueError:
236+
return self.next_agent(last_speaker, agents)
237+
199238
def _participant_roles(self, agents: List[Agent] = None) -> str:
200239
# Default to all agents registered
201240
if agents is None:
@@ -342,7 +381,7 @@ async def a_run_chat(
342381
break
343382
try:
344383
# select the next speaker
345-
speaker = groupchat.select_speaker(speaker, self)
384+
speaker = await groupchat.a_select_speaker(speaker, self)
346385
# let the speaker speak
347386
reply = await speaker.a_generate_reply(sender=self)
348387
except KeyboardInterrupt:

notebook/agentchat_stream.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@
238238
" )\n",
239239
" return False, None\n",
240240
"\n",
241-
"user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, 1, config={\"news_stream\": data})"
241+
"user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={\"news_stream\": data})"
242242
]
243243
},
244244
{

test/agentchat/test_async.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ async def add_data_reply(recipient, messages, sender, config):
146146
)
147147
return False, None
148148

149-
user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, 1, config={"news_stream": data})
149+
user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={"news_stream": data})
150150

151151
await user_proxy.a_initiate_chat(
152152
assistant,

0 commit comments

Comments
 (0)