Skip to content

Commit

Permalink
Make groupchat & generation async, actually (microsoft#543)
Browse files Browse the repository at this point in the history
* make groupchat & generation async actually

* factored out func call pre-select; updated indecies

* fixed code format issue

* mark prepare agents subset as internal

* func renaming

* func inputs

* return agents

* Update test/agentchat/test_async.py

Co-authored-by: Chi Wang <[email protected]>

* Update notebook/agentchat_stream.ipynb

Co-authored-by: Chi Wang <[email protected]>

---------

Co-authored-by: Chi Wang <[email protected]>
Co-authored-by: Qingyun Wu <[email protected]>
  • Loading branch information
3 people authored and rlam3 committed Dec 19, 2023
1 parent 80ea6d9 commit ef9fe73
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 15 deletions.
17 changes: 15 additions & 2 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import copy
import functools
import json
import logging
from collections import defaultdict
Expand Down Expand Up @@ -133,9 +134,10 @@ def __init__(
self._reply_func_list = []
self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.generate_async_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)

Expand Down Expand Up @@ -631,6 +633,17 @@ def generate_oai_reply(
)
return True, client.extract_text_or_function_call(response)[0]

async def a_generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Generate a reply using autogen.oai asynchronously."""
return await asyncio.get_event_loop().run_in_executor(
None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
)

def generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down Expand Up @@ -697,7 +710,7 @@ def generate_function_call_reply(
return True, func_return
return False, None

async def generate_async_function_call_reply(
async def a_generate_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
Expand Down
61 changes: 50 additions & 11 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Tuple

from ..code_utils import content_str
from .agent import Agent
Expand Down Expand Up @@ -118,8 +118,7 @@ def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agent], List[Agent]]:
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
Expand Down Expand Up @@ -148,30 +147,35 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
]
if len(agents) == 1:
# only one agent can execute the function
return agents[0]
return agents[0], agents
elif not agents:
# find all the agents with function_map
agents = [agent for agent in self.agents if agent.function_map]
if len(agents) == 1:
return agents[0]
return agents[0], agents
elif not agents:
raise ValueError(
f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents."
)

# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]

if self.speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(agents)
if selected_agent:
return selected_agent
elif self.speaker_selection_method.lower() == "round_robin":
return self.next_agent(last_speaker, agents)
selected_agent = self.next_agent(last_speaker, agents)
elif self.speaker_selection_method.lower() == "random":
return random.choice(agents)
selected_agent = random.choice(agents)
else:
selected_agent = None
return selected_agent, agents

def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
context = self.messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]
Expand All @@ -196,6 +200,41 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
except ValueError:
return self.next_agent(last_speaker, agents)

async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
final, name = await selector.a_generate_oai_reply(
self.messages
+ [
{
"role": "system",
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
}
]
)
if not final:
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)

# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
)

# Return the result
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker, agents)

def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
Expand Down Expand Up @@ -342,7 +381,7 @@ async def a_run_chat(
break
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
speaker = await groupchat.a_select_speaker(speaker, self)
# let the speaker speak
reply = await speaker.a_generate_reply(sender=self)
except KeyboardInterrupt:
Expand Down
2 changes: 1 addition & 1 deletion notebook/agentchat_stream.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@
" )\n",
" return False, None\n",
"\n",
"user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, 1, config={\"news_stream\": data})"
"user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={\"news_stream\": data})"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion test/agentchat/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def add_data_reply(recipient, messages, sender, config):
)
return False, None

user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, 1, config={"news_stream": data})
user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={"news_stream": data})

await user_proxy.a_initiate_chat(
assistant,
Expand Down

0 comments on commit ef9fe73

Please sign in to comment.