Skip to content

Commit

Permalink
Fix/async function and tool execution (#87)
Browse files Browse the repository at this point in the history
* async run group chat

* conversible agent allow async functions to generate reply

* test for async execution

---------

Co-authored-by: Qingyun Wu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Oct 31, 2023
1 parent b432c1b commit 503c243
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 0 deletions.
73 changes: 73 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.generate_async_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)

def register_reply(
Expand Down Expand Up @@ -661,6 +662,28 @@ def generate_function_call_reply(
return True, func_return
return False, None

async def generate_async_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None,
):
"""Generate a reply using async function call."""
if config is None:
config = self
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
if "function_call" in message:
func_call = message["function_call"]
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
if func and asyncio.coroutines.iscoroutinefunction(func):
_, func_return = await self.a_execute_function(func_call)
return True, func_return

return False, None

def check_termination_and_human_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down Expand Up @@ -1002,6 +1025,56 @@ def execute_function(self, func_call):
"content": str(content),
}

async def a_execute_function(self, func_call):
"""Execute an async function call and return the result.
Override this function to modify the way async functions are executed.
Args:
func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
Returns:
A tuple of (is_exec_success, result_dict).
is_exec_success (boolean): whether the execution is successful.
result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
"""
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)

is_exec_success = False
if func is not None:
# Extract arguments from a json-like string and put it into a dict.
input_string = self._format_json_str(func_call.get("arguments", "{}"))
try:
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
content = f"Error: {e}\n You argument should follow json format."

# Try to execute the function
if arguments is not None:
print(
colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"),
flush=True,
)
try:
if asyncio.coroutines.iscoroutinefunction(func):
content = await func(**arguments)
else:
# Fallback to sync function if the function is not async
content = func(**arguments)
is_exec_success = True
except Exception as e:
content = f"Error: {e}"
else:
content = f"Error: Function {func_name} not found."

return is_exec_success, {
"name": func_name,
"role": "function",
"content": str(content),
}

def generate_init_message(self, **context) -> Union[str, Dict]:
"""Generate the initial message for the agent.
Expand Down
50 changes: 50 additions & 0 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ def __init__(
system_message=system_message,
**kwargs,
)
# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
# Allow async chat if initiated using a_initiate_chat
self.register_reply(Agent, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset)

# self._random = random.Random(seed)

def run_chat(
Expand Down Expand Up @@ -177,3 +182,48 @@ def run_chat(
speaker.send(reply, self, request_reply=False)
message = self.last_message(speaker)
return True, None

async def a_run_chat(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[GroupChat] = None,
):
"""Run a group chat asynchronously."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
groupchat = config
for i in range(groupchat.max_round):
# set the name to speaker's name if the role is not function
if message["role"] != "function":
message["name"] = speaker.name
groupchat.messages.append(message)
# broadcast the message to all agents except the speaker
for agent in groupchat.agents:
if agent != speaker:
await self.a_send(message, agent, request_reply=False, silent=True)
if i == groupchat.max_round - 1:
# the last round
break
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
# let the speaker speak
reply = await speaker.a_generate_reply(sender=self)
except KeyboardInterrupt:
# let the admin agent speak if interrupted
if groupchat.admin_name in groupchat.agent_names:
# admin agent is one of the participants
speaker = groupchat.agent_by_name(groupchat.admin_name)
reply = await speaker.a_generate_reply(sender=self)
else:
# admin agent is not found in the participants
raise
if reply is None:
break
# The speaker sends the message without requesting a reply
await speaker.a_send(reply, self, request_reply=False)
message = self.last_message(speaker)
return True, None
61 changes: 61 additions & 0 deletions test/test_function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,68 @@ def get_number():
assert user.execute_function(func_call)[1]["content"] == "42"


@pytest.mark.asyncio
async def test_a_execute_function():
from autogen.agentchat import UserProxyAgent
import time

# Create an async function
async def add_num(num_to_be_added):
given_num = 10
time.sleep(1)
return num_to_be_added + given_num

user = UserProxyAgent(name="test", function_map={"add_num": add_num})
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}

# Asset coroutine doesn't match.
assert user.execute_function(func_call=correct_args)[1]["content"] != "15"
# Asset awaited coroutine does match.
assert (await user.a_execute_function(func_call=correct_args))[1]["content"] == "15"

# function name called is wrong or doesn't exist
wrong_func_name = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
assert "Error: Function" in (await user.a_execute_function(func_call=wrong_func_name))[1]["content"]

# arguments passed is not in correct json format
wrong_json_format = {
"name": "add_num",
"arguments": '{ "num_to_be_added": 5, given_num: 10 }',
} # should be "given_num" with quotes
assert (
"You argument should follow json format."
in (await user.a_execute_function(func_call=wrong_json_format))[1]["content"]
)

# function execution error with wrong arguments passed
wrong_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
assert "Error: " in (await user.a_execute_function(func_call=wrong_args))[1]["content"]

# 2. test calling a class method
class AddNum:
def __init__(self, given_num):
self.given_num = given_num

def add(self, num_to_be_added):
self.given_num = num_to_be_added + self.given_num
return self.given_num

user = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
func_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
assert (await user.a_execute_function(func_call=func_call))[1]["content"] == "15"
assert (await user.a_execute_function(func_call=func_call))[1]["content"] == "20"

# 3. test calling a function with no arguments
def get_number():
return 42

user = UserProxyAgent("user", function_map={"get_number": get_number})
func_call = {"name": "get_number", "arguments": "{}"}
assert (await user.a_execute_function(func_call))[1]["content"] == "42"


if __name__ == "__main__":
test_json_extraction()
test_execute_function()
test_a_execute_function()
test_eval_math_responses()

0 comments on commit 503c243

Please sign in to comment.