From cd7d91da172a3d438ed37586ce52d52e9ae6dd78 Mon Sep 17 00:00:00 2001 From: Shaokun Zhang Date: Tue, 26 Mar 2024 16:31:02 -0400 Subject: [PATCH] Integrate AgentOptimizer (#1767) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * draft agent optimizer * refactor * remove * change openai config interface * notebook * update blog * add test * clean up * redir * update * update interface * change model name * move to contrib * Update autogen/agentchat/contrib/agent_optimizer.py Co-authored-by: Jack Gerrits --------- Co-authored-by: “skzhang1” <“shaokunzhang529@gmail.com”> Co-authored-by: Beibin Li Co-authored-by: Jieyu Zhang Co-authored-by: Jack Gerrits --- autogen/agentchat/contrib/agent_optimizer.py | 440 ++++++++++++++ notebook/agentchat_agentoptimizer.ipynb | 546 +++--------------- .../agentchat/contrib/test_agent_optimizer.py | 110 ++++ .../blog/2023-12-23-AgentOptimizer/index.mdx | 144 ++--- 4 files changed, 708 insertions(+), 532 deletions(-) create mode 100644 autogen/agentchat/contrib/agent_optimizer.py create mode 100644 test/agentchat/contrib/test_agent_optimizer.py diff --git a/autogen/agentchat/contrib/agent_optimizer.py b/autogen/agentchat/contrib/agent_optimizer.py new file mode 100644 index 000000000000..711874efc8fa --- /dev/null +++ b/autogen/agentchat/contrib/agent_optimizer.py @@ -0,0 +1,440 @@ +from autogen.code_utils import execute_code +from typing import List, Dict, Optional +import json +import copy +import autogen + +ADD_FUNC = { + "type": "function", + "function": { + "name": "add_function", + "description": "Add a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the function in the code implementation."}, + "description": {"type": "string", "description": "A short description of the function."}, + "arguments": { + "type": "string", + "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.', + }, + "packages": { + "type": "string", + "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.", + }, + "code": { + "type": "string", + "description": "The implementation in Python. Do not include the function declaration.", + }, + }, + "required": ["name", "description", "arguments", "packages", "code"], + }, + }, +} + +REVISE_FUNC = { + "type": "function", + "function": { + "name": "revise_function", + "description": "Revise a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the function in the code implementation."}, + "description": {"type": "string", "description": "A short description of the function."}, + "arguments": { + "type": "string", + "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.', + }, + "packages": { + "type": "string", + "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.", + }, + "code": { + "type": "string", + "description": "The implementation in Python. Do not include the function declaration.", + }, + }, + "required": ["name", "description", "arguments", "packages", "code"], + }, + }, +} + +REMOVE_FUNC = { + "type": "function", + "function": { + "name": "remove_function", + "description": "Remove one function in the context of the conversation. Once remove one function, the assistant will not use this function in future conversation.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the function in the code implementation."} + }, + "required": ["name"], + }, + }, +} + +OPT_PROMPT = """You are a function optimizer. Your task is to maintain a list of functions for the assistant according to the existing function list and conversation history that happens between the assistant and the user. +You can perform one of the following four actions to manipulate the function list using the functions you have: +1. Revise one existing function (using revise_function). +2. Remove one existing function (using remove_function). +3. Add one new function (using add_function). +4. Directly return "TERMINATE" to me if no more actions are needed for the current function list. + +Below are the principles that you need to follow for taking these four actions. +(1) Revise one existing function: +1. Pay more attention to the failed tasks and corresponding error information, and optimize the function used in these tasks according to the conversation history if needed. +2. A failed function call can occur due to incorrect input arguments (missing arguments) or an incorrect function code implementation. You should focus more on the function code implementation and make it easy to get success function call. +3. Do not revise the function that you think works well and plays a critical role in solving the problems according to the conversation history. Only making revisions if needed. +4. Sometimes, a NameError may occur. To fix this error, you can either revise the name of the function in the code implementation or revise the name of the function call to make these two names consistent. +(2) Remove one existing function: +1. Only remove the function that you think is not needed anymore in future tasks. +(3) Add one new function: +1. The added function should be general enough to be used in future tasks. For instance, if you encounter a problem that this function can solve, or one step of it, you can use the generated function directly instead of starting from scratch +2. The added new function should solve a higher-level question that encompasses the original query and extend the code's functionality to make it more versatile and widely applicable. +3. Replace specific strings or variable names with general variables to enhance the tool's applicability to various queries. All names used inside the function should be passed in as arguments. +Below is an example of a function that potentially deserves to be adde in solving MATH problems, which can be used to solve a higher-level question: +{{ + \"name\": \"evaluate_expression\", + \"description\": \"Evaluate arithmetic or mathematical expressions provided as strings.\", + \"arguments\": {{ + \"expression\": {{ + \"type\": \"string\", + \"description\": \"The mathematical expression to evaluate.\" + }} + }}, + \"packages\": \"sympy\", + \"code\": \"from sympy import sympify, SympifyError\\n\\ndef evaluate_expression(expression):\\n try:\\n result = sympify(expression)\\n if result.is_number:\\n result = float(result)\\n else:\\n result = str(result)\\n return result\\n except SympifyError as e:\\n return str(e)\" +}} +(4) Directly return "TERMINATE": +If you think there is no need to perform any other actions for the current function list since the current list is optimal more actions will harm the performance in future tasks. Please directly reply to me with "TERMINATE". + +One function signature includes the following five elements: +1. Function name +2. Function description +3. JSON schema of arguments encoded as a string +4. A list of package names imported by the function packages +5. The code implementation + +Below are the signatures of the current functions: +List A: {best_functions}. +The following list are the function signatures that you have after taking {actions_num} actions to manipulate List A: +List B: {incumbent_functions}. + +{accumulated_experience} + +Here are {best_conversations_num} conversation histories of solving {best_conversations_num} tasks using List A. +History: +{best_conversations_history} + +{statistic_informations} + +According to the information I provide, please take one of four actions to manipulate list B using the functions you know. +Instead of returning TERMINATE directly or taking no action, you should try your best to optimize the function list. Only take no action if you really think the current list is optimal, as more actions will harm performance in future tasks. +Even adding a general function that can substitute the assistant’s repeated suggestions of Python code with the same functionality could also be helpful. +""" + + +def execute_func(name, packages, code, **args): + """ + The wrapper for generated functions. + """ + pip_install = ( + f"""print("Installing package: {packages}")\nsubprocess.run(["pip", "-qq", "install", "{packages}"])""" + if packages + else "" + ) + str = f""" +import subprocess +{pip_install} +print("Result of {name} function execution:") +{code} +args={args} +result={name}(**args) +if result is not None: print(result) +""" + print(f"execute_code:\n{str}") + result = execute_code(str, use_docker="shaokun529/evoagent:v1") + if result[0] != 0: + raise Exception("Error in executing function:" + result[1]) + print(f"Result: {result[1]}") + return result[1] + + +class AgentOptimizer: + """ + Base class for optimizing AutoGen agents. Specifically, it is used to optimize the functions used in the agent. + More information could be found in the following paper: https://arxiv.org/abs/2402.11359. + """ + + def __init__( + self, + max_actions_per_step: int, + config_file_or_env: Optional[str] = "OAI_CONFIG_LIST", + config_file_location: Optional[str] = "", + optimizer_model: Optional[str] = "gpt-4-1106-preview", + ): + """ + (These APIs are experimental and may change in the future.) + Args: + max_actions_per_step (int): the maximum number of actions that the optimizer can take in one step. + config_file_or_env: path or environment of the OpenAI api configs. + config_file_location: the location of the OpenAI config file. + optimizer_model: the model used for the optimizer. + """ + self.max_actions_per_step = max_actions_per_step + self._max_trials = 3 + self.optimizer_model = optimizer_model + + self._trial_conversations_history = [] + self._trial_conversations_performance = [] + self._trial_functions = [] + + self._best_conversations_history = [] + self._best_conversations_performance = [] + self._best_functions = [] + + self._failure_functions_performance = [] + self._best_performance = -1 + + config_list = autogen.config_list_from_json( + config_file_or_env, + file_location=config_file_location, + filter_dict={"model": [self.optimizer_model]}, + ) + if len(config_list) == 0: + raise RuntimeError("No valid openai config found in the config file or environment variable.") + self._client = autogen.OpenAIWrapper(config_list=config_list) + + def record_one_conversation(self, conversation_history: List[Dict], is_satisfied: bool = None): + """ + record one conversation history. + Args: + conversation_history (List[Dict]): the chat messages of the conversation. + is_satisfied (bool): whether the user is satisfied with the solution. If it is none, the user will be asked to input the satisfaction. + """ + if is_satisfied is None: + reply = input( + "Please provide whether the user is satisfied with the solution. 1 represents satisfied. 0 represents not satisfied. Press enter to submit. \n" + ) + assert reply in [ + "0", + "1", + ], "The input is invalid. Please input 1 or 0. 1 represents satisfied. 0 represents not satisfied." + is_satisfied = True if reply == "1" else False + self._trial_conversations_history.append( + {"Conversation {i}".format(i=len(self._trial_conversations_history)): conversation_history} + ) + self._trial_conversations_performance.append( + {"Conversation {i}".format(i=len(self._trial_conversations_performance)): 1 if is_satisfied else 0} + ) + + def step(self): + """ + One step of training. It will return register_for_llm and register_for_executor at each iteration, + which are subsequently utilized to update the assistant and executor agents, respectively. + See example: https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentoptimizer.ipynb + """ + performance = sum(sum(d.values()) for d in self._trial_conversations_performance) / len( + self._trial_conversations_performance + ) + + if performance < self._best_performance: + self._failure_functions_performance.append({"functions": self._trial_functions, "performance": performance}) + self._failure_functions_performance = sorted( + self._failure_functions_performance, key=lambda x: x["performance"] + ) + else: + self._failure_functions_performance = [] + self._best_performance = performance + self._best_functions = copy.deepcopy(self._trial_functions) + self._best_conversations_history = copy.deepcopy(self._trial_conversations_history) + self._best_conversations_performance = copy.deepcopy(self._trial_conversations_performance) + self._trial_conversations_history = [] + self._trial_conversations_performance = [] + + best_functions = copy.deepcopy(self._best_functions) + incumbent_functions = copy.deepcopy(self._best_functions) + failure_experience_prompt, statistic_prompt = self._construct_intermediate_prompt() + + for action_index in range(self.max_actions_per_step): + prompt = OPT_PROMPT.format( + best_conversations_history=self._best_conversations_history, + best_conversations_num=len(self._best_conversations_history), + actions_num=action_index, + best_functions=best_functions, + incumbent_functions=incumbent_functions, + accumerated_experience=failure_experience_prompt, + statistic_informations=statistic_prompt, + ) + messages = [{"role": "user", "content": prompt}] + for _ in range(self._max_trials): + response = self._client.create( + messages=messages, tools=[ADD_FUNC, REVISE_FUNC, REMOVE_FUNC], tool_choice="auto" + ) + actions = response.choices[0].message.tool_calls + if self._validate_actions(actions, incumbent_functions): + break + if actions is not None and self._validate_actions(actions, incumbent_functions): + incumbent_functions = self._update_function_call(incumbent_functions, actions) + + remove_functions = list( + set([key for dictionary in self._trial_functions for key in dictionary.keys()]) + - set([key for dictionary in incumbent_functions for key in dictionary.keys()]) + ) + + register_for_llm = [] + register_for_exector = {} + for name in remove_functions: + register_for_llm.append({"func_sig": {"name": name}, "is_remove": True}) + register_for_exector.update({name: None}) + for func in incumbent_functions: + register_for_llm.append( + { + "func_sig": { + "name": func.get("name"), + "description": func.get("description"), + "parameters": {"type": "object", "properties": func.get("arguments")}, + }, + "is_remove": False, + } + ) + register_for_exector.update( + { + func.get("name"): lambda **args: execute_func( + func.get("name"), func.get("packages"), func.get("code"), **args + ) + } + ) + + self._trial_functions = incumbent_functions + return register_for_llm, register_for_exector + + def reset_optimizer(self): + """ + reset the optimizer. + """ + + self._trial_conversations_history = [] + self._trial_conversations_performance = [] + self._trial_functions = [] + + self._best_conversations_history = [] + self._best_conversations_performance = [] + self._best_functions = [] + + self._best_performance = -1 + self._failure_functions_performance = [] + + def _update_function_call(self, incumbent_functions, actions): + """ + update function call. + """ + + formated_actions = [] + for action in actions: + func = json.loads(action.function.arguments.strip('"')) + func["action_name"] = action.function.name + + if func.get("action_name") == "remove_function": + item = { + "action_name": func.get("action_name"), + "name": func.get("name"), + } + else: + item = { + "action_name": func.get("action_name"), + "name": func.get("name"), + "description": func.get("description"), + "arguments": json.loads(func.get("arguments").strip('"')), + "packages": func.get("packages"), + "code": func.get("code"), + } + formated_actions.append(item) + actions = formated_actions + + for action in actions: + name, description, arguments, packages, code, action_name = ( + action.get("name"), + action.get("description"), + action.get("arguments"), + action.get("packages"), + action.get("code"), + action.get("action_name"), + ) + if action_name == "remove_function": + incumbent_functions = [item for item in incumbent_functions if item["name"] != name] + else: + incumbent_functions = [item for item in incumbent_functions if item["name"] != name] + incumbent_functions.append( + { + "name": name, + "description": description, + "arguments": arguments, + "packages": packages, + "code": code, + } + ) + + return incumbent_functions + + def _construct_intermediate_prompt(self): + """ + construct intermediate prompts. + """ + if len(self._failure_functions_performance) != 0: + failure_experience_prompt = "We also provide more examples for different functions and their corresponding performance (0-100).\n The following function signatures are arranged in are arranged in ascending order based on their performance, where higher performance indicate better quality." + failure_experience_prompt += "\n" + for item in self._failure_functions_performance: + failure_experience_prompt += "Function: \n" + str(item["functions"]) + "\n" + failure_experience_prompt += "Performance: \n" + str(item["performance"]) + "\n" + else: + failure_experience_prompt = "\n" + + if len(self._best_conversations_performance) != 0: + statistic_prompt = "The following table shows the statistical information for solving each task in each conversation and indicates, whether the result is satisfied by the users. 1 represents satisfied. 0 represents not satisfied." + statistic_prompt += "\n" + for item in self._best_conversations_performance: + statistic_prompt += str(item) + "\n" + else: + statistic_prompt = "\n" + + return failure_experience_prompt, statistic_prompt + + def _validate_actions(self, actions, incumbent_functions): + """ + validate whether the proposed actions are feasible. + """ + if actions is None: + return True + else: + # val json format + for action in actions: + function_args = action.function.arguments + try: + function_args = json.loads(function_args.strip('"')) + if "arguments" in function_args.keys(): + json.loads(function_args.get("arguments").strip('"')) + except Exception as e: + print("JSON is invalid:", e) + return False + # val syntax + for action in actions: + if action.function.name != "remove_function": + function_args = json.loads(action.function.arguments.strip('"')) + code = function_args.get("code") + try: + compile(code, "", "exec") + print("successfully compiled") + except Exception as e: + print("Syntax is invalid:", e) + return False + for action in actions: + action_name = action.function.name + if action_name == "remove_function": + function_args = json.loads(action.function.arguments.strip('"')) + if function_args.get("name") not in [item["name"] for item in incumbent_functions]: + print("The function you want to remove does not exist.") + return False + return True diff --git a/notebook/agentchat_agentoptimizer.ipynb b/notebook/agentchat_agentoptimizer.ipynb index a619ce551385..4b7de9715be1 100644 --- a/notebook/agentchat_agentoptimizer.ipynb +++ b/notebook/agentchat_agentoptimizer.ipynb @@ -11,14 +11,15 @@ "\n", "In traditional ML pipeline, we train a model by updating its parameter according to the loss on the training set, while in the era of LLM agents, how should we train an agent? Here, we take an initial step towards the agent training. Inspired by the [function calling](https://platform.openai.com/docs/guides/function-calling) capabilities provided by OpenAI, we draw an analogy between model parameters and agent functions/skills, and update agent’s functions/skills based on its historical performance on the training set. As an agentic way of training an agent, our approach help enhance the agents’ abilities without requiring access to the LLMs parameters.\n", "\n", - "In this notebook, we introduce a new class, ‘AgentOptimizer’, which is able to improve the function list of one Assistant-UserProxy pair according to the historical conversation histories. This feature would support agents in improving their ability to solve problems of the same type as previous tasks.\n", - "Specifically, given a set of training data, AgentOptimizer would iteratively prompt the LLM to optimize the existing function list of the AssistantAgent and UserProxyAgent with code implementation if necessary.\n", + "In this notebook, we introduce a new class, ‘AgentOptimizer’, which is able to improve the function list of one Assistant-UserProxy pair according to the historical conversation histories.\n", + "This feature would support agents in improving their ability to solve problems of the same type as previous tasks.\n", + "Specifically, given a set of training data, AgentOptimizer would iteratively prompt the LLM to optimize the existing function list of the AssistantAgent and UserProxyAgent with code implementation if necessary. It also includes two strategies, roll-back, and early-stop, to streamline the training process.\n", "In the example scenario, we test the proposed AgentOptimizer in solving problems from the [MATH dataset](https://github.com/hendrycks/math). \n", "\n", - "Paper is coming soon!\n", - "\n", "![AgentEval](../website/blog/2023-12-23-AgentOptimizer/img/agentoptimizer.png)\n", "\n", + "More information could be found in the [paper](https://arxiv.org/abs/2402.11359).\n", + "\n", "Authors:\n", "- [Shaokun Zhang](https://github.com/skzhang1), Ph.D. student at the The Pennsylvania State University\n", "- [Jieyu Zhang](https://jieyuz2.github.io), Ph.D. student at the University of Washington" @@ -26,351 +27,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ - "import json\n", - "import os\n", "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n", - "\n", - "from openai import AzureOpenAI, BadRequestError\n", - "\n", - "import autogen\n", - "from autogen.agentchat import Agent\n", + "from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer\n", "from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent\n", - "from autogen.code_utils import execute_code, extract_code\n", - "from autogen.math_utils import get_answer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# AgentOptimizer\n", - "\n", - "AgentOptimizer is a class that is designed to improve the agents through optimizing its function call. It contains two core methods:\n", - "\n", - "1. `step()`: `step()` has three inputs: previous conversation history (history), the statistical information of solving previous problems (statistic), and the signature of current functions (func_signature). The output is a series of actions to manipulate the current functions.\n", - "\n", - "2. `update_function_call()`: This method updates the functions registered in the agents according to the actions from `step()`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class AgentOptimizer:\n", - " OPT_PROMPT = \"\"\"You are a function optimizer. Your task is to maintain a list of functions for the assistant according to the existing function list and conversation history that happens between the assistant and the user.\n", - "You can perform one of the following four actions to manipulate the function list using the functions you have:\n", - "1. Revise one existing function (using revise_function).\n", - "2. Remove one existing function (using remove_function).\n", - "3. Add one new function (using add_function).\n", - "4. Directly return \"TERMINATE\" to me if no more actions are needed for the current function list.\n", - "\n", - "Below are the principles that you need to follow for taking these four actions.\n", - "(1) Revise one existing function:\n", - "1. Pay more attention to the failed tasks and corresponding error information, and optimize the function used in these tasks according to the conversation history if needed.\n", - "2. A failed function call can occur due to incorrect input arguments (missing arguments) or an incorrect function code implementation. You should focus more on the function code implementation and make it easy to get success function call.\n", - "3. Do not revise the function that you think works well and plays a critical role in solving the problems according to the conversation history. Only making revisions if needed.\n", - "4. Sometimes, a NameError may occur. To fix this error, you can either revise the name of the function in the code implementation or revise the name of the function call to make these two names consistent.\n", - "(2) Remove one existing function:\n", - "1. Only remove the function that you think is not needed anymore in future tasks.\n", - "(3) Add one new function:\n", - "1. The added new function should solve a higher-level question that encompasses the original query and extend the code's functionality to make it more versatile and widely applicable.\n", - "2. The added new function should solve queries of the same type, based on common reasoning steps without mentioning specific object names or entity terms.\n", - "3. Name the function and write the description concerning both the core reasoning pattern and data organization format, without referencing specific objects. The name of the function MUST be the same with the function name in the code you generated.\n", - "4. Replace specific strings or variable names with general variables to enhance the tool's applicability to various queries. All names used inside the function should be passed in as arguments.\n", - "(4) Directly return \"TERMINATE\":\n", - "If you think there is no need to perform any other actions for the current function list since the current list is optimal more actions will harm the performance in future tasks. Please directly reply to me with \"TERMINATE\".\n", - "\n", - "One function signature includes the following five elements:\n", - "1. Function name\n", - "2. Function description\n", - "3. JSON schema of arguments encoded as a string\n", - "4. A list of package names imported by the function packages\n", - "5. The code implementation\n", - "\n", - "Below are the signatures of the current functions:\n", - "List A: {signiture}.\n", - "The success rate (performance) with this function list is {success_rate}.\n", - "The following list are the function signatures that you have after taking {actions_num} actions in our previous conversations:\n", - "List B: {after_signiture}.\n", - "Here are {conversation_num} conversation histories of solving {conversation_num} tasks.\n", - "History:\n", - "{history}\n", - "The following table shows the statistical information for solving each task in each conversation and indicates whether each task was successfully solved.\n", - "1 represents correct. 0 represents wrong.\n", - "statistic:\n", - "{statistic}\n", - "\n", - "According to the information I provide, please take one of four actions to manipulate list B using the functions you know.\n", - " \"\"\"\n", - "\n", - " ADD_FUNC = {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"add_function\",\n", - " \"description\": \"Add a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\", \"description\": \"The name of the function in the code implementation.\"},\n", - " \"description\": {\"type\": \"string\", \"description\": \"A short description of the function.\"},\n", - " \"arguments\": {\n", - " \"type\": \"string\",\n", - " \"description\": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { \"url\": { \"type\": \"string\", \"description\": \"The URL\", }}. Please avoid the error \\'array schema missing items\\' when using array type.',\n", - " },\n", - " \"packages\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.\",\n", - " },\n", - " \"code\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The implementation in Python. Do not include the function declaration.\",\n", - " },\n", - " },\n", - " \"required\": [\"name\", \"description\", \"arguments\", \"packages\", \"code\"],\n", - " },\n", - " },\n", - " }\n", - "\n", - " REVISE_FUNC = {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"revise_function\",\n", - " \"description\": \"Revise a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\", \"description\": \"The name of the function in the code implementation.\"},\n", - " \"description\": {\"type\": \"string\", \"description\": \"A short description of the function.\"},\n", - " \"arguments\": {\n", - " \"type\": \"string\",\n", - " \"description\": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { \"url\": { \"type\": \"string\", \"description\": \"The URL\", }}. Please avoid the error \\'array schema missing items\\' when using array type.',\n", - " },\n", - " \"packages\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.\",\n", - " },\n", - " \"code\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The implementation in Python. Do not include the function declaration.\",\n", - " },\n", - " },\n", - " \"required\": [\"name\", \"description\", \"arguments\", \"packages\", \"code\"],\n", - " },\n", - " },\n", - " }\n", - "\n", - " REMOVE_FUNC = {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"remove_function\",\n", - " \"description\": \"Remove one function in the context of the conversation. Once remove one function, the assistant will not use this function in future conversation.\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\", \"description\": \"The name of the function in the code implementation.\"}\n", - " },\n", - " \"required\": [\"name\"],\n", - " },\n", - " },\n", - " }\n", - "\n", - " def __init__(self, OAI_config, action_num=3, each_action_max_trials=10):\n", - " self._action_num = action_num\n", - " self._each_action_max_trials = each_action_max_trials\n", - " os.environ[\"AZURE_OPENAI_API_KEY\"] = OAI_config[\"AZURE_OPENAI_API_KEY\"] # TODO: input key into client\n", - " self._client = AzureOpenAI(\n", - " api_version=OAI_config[\"api_version\"],\n", - " azure_endpoint=OAI_config[\"azure_endpoint\"],\n", - " )\n", - " self.model = \"gpt-4-1106-preview\"\n", - "\n", - " def _val_json(self, actions):\n", - " if actions is None:\n", - " return True\n", - " else:\n", - " for action in actions:\n", - " function_args = action.function.arguments\n", - " try:\n", - " function_args = json.loads(function_args.strip('\"'))\n", - " if \"arguments\" in function_args.keys():\n", - " json.loads(function_args.get(\"arguments\").strip('\"'))\n", - " except Exception as e:\n", - " print(\"JSON is invalid:\", e)\n", - " return False\n", - " return True\n", - "\n", - " def _val_remove(self, actions, after_signiture):\n", - " if actions is None:\n", - " return True\n", - " else:\n", - " for action in actions:\n", - " action_name = action.function.name\n", - " if action_name == \"remove_function\":\n", - " function_args = json.loads(action.function.arguments.strip('\"'))\n", - " if function_args.get(\"name\") not in [item[\"name\"] for item in after_signiture]:\n", - " print(\"The function you want to remove does not exist.\")\n", - " return False\n", - " return True\n", - "\n", - " def _val_syntax(self, actions):\n", - " if actions is None:\n", - " return True\n", - " else:\n", - " for action in actions:\n", - " if action.function.name != \"remove_function\":\n", - " function_args = json.loads(action.function.arguments.strip('\"'))\n", - " code = function_args.get(\"code\")\n", - " try:\n", - " compile(code, \"\", \"exec\")\n", - " print(\"successfully compiled\")\n", - " except SyntaxError as e:\n", - " print(\"Syntax is invalid:\", e)\n", - " return False\n", - " return True\n", - "\n", - " def _format_actions(self, actions):\n", - " ans = []\n", - " for action in actions:\n", - " func = json.loads(action.function.arguments.strip('\"'))\n", - " func[\"action_name\"] = action.function.name\n", - "\n", - " if func.get(\"action_name\") == \"remove_function\":\n", - " item = {\n", - " \"action_name\": func.get(\"action_name\"),\n", - " \"name\": func.get(\"name\"),\n", - " }\n", - " else:\n", - " item = {\n", - " \"action_name\": func.get(\"action_name\"),\n", - " \"name\": func.get(\"name\"),\n", - " \"description\": func.get(\"description\"),\n", - " \"arguments\": json.loads(func.get(\"arguments\").strip('\"')),\n", - " \"packages\": func.get(\"packages\"),\n", - " \"code\": func.get(\"code\"),\n", - " }\n", - " ans.append(item)\n", - " return ans\n", - "\n", - " def _get_success_rate(self, statistic):\n", - " sum = 0\n", - " for key, value in statistic.items():\n", - " if \"is_correct\" not in value.keys():\n", - " statistic[key][\"is_correct\"] = 0\n", - " for key, value in statistic.items():\n", - " sum += value[\"is_correct\"]\n", - " if len(statistic.keys()) != 0:\n", - " success_rate = sum / len(statistic.keys())\n", - " else:\n", - " success_rate = None\n", - " return success_rate, statistic\n", - "\n", - " def _modify_function_signiture(self, cur_functions, action_json):\n", - " for action in action_json:\n", - " action_name = action.get(\"action_name\")\n", - " if action_name != \"remove_function\":\n", - " cur_functions = [item for item in cur_functions if item[\"name\"] != action.get(\"name\")]\n", - " cur_functions.append(\n", - " {\n", - " \"name\": action.get(\"name\"),\n", - " \"description\": action.get(\"description\"),\n", - " \"arguments\": action.get(\"arguments\"),\n", - " \"packages\": action.get(\"packages\"),\n", - " \"code\": action.get(\"code\"),\n", - " }\n", - " )\n", - " else:\n", - " cur_functions = [item for item in cur_functions if item[\"name\"] != action.get(\"name\")]\n", - " return cur_functions\n", - "\n", - " def update_function_call(self, action, mathproxyagent, assistant):\n", - " def execute_func(name, packages, code, **args):\n", - " pip_install = (\n", - " f\"\"\"print(\"Installing package: {packages}\")\\nsubprocess.run([\"pip\", \"-qq\", \"install\", \"{packages}\"])\"\"\"\n", - " if packages\n", - " else \"\"\n", - " )\n", - " str = f\"\"\"\n", - "import subprocess\n", - "{pip_install}\n", - "print(\"Result of {name} function execution:\")\n", - "{code}\n", - "args={args}\n", - "result={name}(**args)\n", - "if result is not None: print(result)\n", - "\"\"\"\n", - " print(f\"execute_code:\\n{str}\")\n", - " result = execute_code(str)\n", - " if result[0] != 0:\n", - " raise Exception(\"Error in executing function:\" + result[1])\n", - " print(f\"Result: {result[1]}\")\n", - " return result[1]\n", - "\n", - " name, description, arguments, packages, code, action_name = (\n", - " action.get(\"name\"),\n", - " action.get(\"description\"),\n", - " action.get(\"arguments\"),\n", - " action.get(\"packages\"),\n", - " action.get(\"code\"),\n", - " action.get(\"action_name\"),\n", - " )\n", - "\n", - " if name in mathproxyagent._function_map.keys():\n", - " del mathproxyagent._function_map[name]\n", - " if action_name != \"remove_function\":\n", - " function_config = {\n", - " \"name\": name,\n", - " \"description\": description,\n", - " \"parameters\": {\"type\": \"object\", \"properties\": arguments},\n", - " }\n", - " mathproxyagent.register_function(\n", - " function_map={name: lambda **args: execute_func(name, packages, code, **args)}\n", - " )\n", - " assistant.update_function_signature(function_config, is_remove=False)\n", - " else:\n", - " assistant.update_function_signature(name, is_remove=True)\n", - "\n", - " def step(self, history, statistic, func_signiture):\n", - " action_return = []\n", - " origin_signiture = func_signiture\n", - " modified_signiture = origin_signiture\n", - "\n", - " success_rate, statistic = self._get_success_rate(statistic) # TODO: make statistic feasible outside of the loop\n", - " for action_index in range(self._action_num):\n", - " prompt = self.OPT_PROMPT.format(\n", - " conversation_num=len(history),\n", - " statistic={\"is_correct\": statistic},\n", - " signiture=origin_signiture,\n", - " history=history,\n", - " success_rate=success_rate,\n", - " actions_num=action_index,\n", - " after_signiture=modified_signiture,\n", - " )\n", - " messages = [{\"role\": \"user\", \"content\": prompt}]\n", - " for _ in range(self._each_action_max_trials):\n", - " response = self._client.chat.completions.create(\n", - " model=self.model,\n", - " messages=messages,\n", - " tools=[self.ADD_FUNC, self.REVISE_FUNC, self.REMOVE_FUNC],\n", - " tool_choice=\"auto\",\n", - " )\n", - " actions = response.choices[0].message.tool_calls\n", - " if (\n", - " self._val_json(actions)\n", - " and self._val_syntax(actions)\n", - " and self._val_remove(actions, modified_signiture)\n", - " ):\n", - " break\n", - " if actions is not None:\n", - " action_result = self._format_actions(actions)\n", - " action_return = action_return + action_result\n", - " modified_signiture = self._modify_function_signiture(modified_signiture, action_result)\n", - " return action_return, modified_signiture" + "from autogen.agentchat import Agent\n", + "from openai import BadRequestError\n", + "from autogen.code_utils import extract_code\n", + "from autogen.math_utils import get_answer\n", + "from autogen import config_list_from_json\n", + "import autogen\n", + "import json\n", + "import copy" ] }, { @@ -379,18 +50,19 @@ "source": [ "# MathUserProxy with function_call\n", "\n", - "This agent is a customozied MathUserProxy inherits from its [partent class](https://github.com/microsoft/autogen/blob/main/autogen/agentchat/contrib/math_user_proxy_agent.py.) \n", + "This agent is a customozied MathUserProxy inherits from its [partent class](https://github.com/microsoft/autogen/blob/main/autogen/agentchat/contrib/math_user_proxy_agent.py.).\n", "\n", - "It supports using both function_call and python to solve math problems." + "It supports using both function_call and python to solve math problems.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def is_termination_msg_mathchat(message):\n", + " \"\"\"Check if a message is a termination message.\"\"\"\n", " if isinstance(message, dict):\n", " message = message.get(\"content\")\n", " if message is None:\n", @@ -398,12 +70,11 @@ " cb = extract_code(message)\n", " contain_code = False\n", " for c in cb:\n", - " if c[0] == \"python\" or c[0] == \"wolfram\":\n", + " if c[0] == \"python\":\n", " contain_code = True\n", " break\n", " if message.rstrip().find(\"TERMINATE\") >= 0:\n", " return True\n", - "\n", " return not contain_code and get_answer(message) is not None and get_answer(message) != \"\"\n", "\n", "\n", @@ -471,6 +142,7 @@ " self.max_function_call_trial = 3\n", " self.query = None\n", " self.answer = None\n", + " self.is_correct = None\n", "\n", " def generate_function_call_reply(\n", " self,\n", @@ -490,7 +162,6 @@ " else:\n", " if self.max_function_call_trial == 0:\n", " error_message = func_return[\"content\"]\n", - " self.logs[\"is_correct\"] = 0\n", " self.max_function_call_trial = 3\n", " return (\n", " True,\n", @@ -509,12 +180,11 @@ " def initiate_chat(\n", " self,\n", " recipient,\n", - " query: None,\n", " answer: None,\n", " silent: Optional[bool] = False,\n", " **context,\n", " ):\n", - " self.query = query\n", + " self.query = context[\"problem\"]\n", " if not isinstance(answer, str):\n", " answer = str(answer)\n", " if answer.endswith(\".0\"):\n", @@ -522,29 +192,23 @@ " self._answer = answer\n", " else:\n", " self._answer = answer\n", - " self.logs = {}\n", - " self._prepare_chat(recipient, True)\n", "\n", - " chat_history = []\n", - " error_message = None\n", + " self.is_correct = None\n", "\n", + " self._prepare_chat(recipient, True)\n", + " error_message = None\n", " try:\n", " prompt = self.PROMPTS + context[\"problem\"]\n", " self.send(prompt, recipient, silent=silent)\n", " except BadRequestError as e:\n", " error_message = str(e)\n", - " self.logs[\"is_correct\"] = 0\n", + " self.is_correct = 0\n", " print(\"error information: {}\".format(error_message))\n", "\n", - " key = list(self.chat_messages.keys())[0]\n", - " chat_messages = self.chat_messages[key]\n", - " for item in chat_messages:\n", - " chat_history.append(item)\n", - " if error_message is not None:\n", - " chat_history.append(error_message)\n", " recipient.reset()\n", - " self.reset()\n", - " return self.logs, chat_history\n", + " is_correct = copy.deepcopy(self.is_correct)\n", + " self._reset()\n", + " return is_correct\n", "\n", " def _check_final_result(\n", " self,\n", @@ -552,8 +216,8 @@ " sender: Optional[autogen.Agent] = None,\n", " config: Optional[Any] = None,\n", " ):\n", - " messages = messages[-1]\n", "\n", + " messages = messages[-1]\n", " if isinstance(messages, dict):\n", " messages = messages.get(\"content\")\n", " if messages is None:\n", @@ -562,29 +226,25 @@ " cb = extract_code(messages)\n", " contain_code = False\n", " for c in cb:\n", - " if c[0] == \"python\" or c[0] == \"wolfram\":\n", + " if c[0] == \"python\":\n", " contain_code = True\n", " break\n", " if not contain_code and get_answer(messages) is not None and get_answer(messages) != \"\":\n", " if get_answer(messages) == self._answer:\n", - " self.logs[\"is_correct\"] = 1\n", + " self.is_correct = 1\n", " return True, \"The result is Correct. Please reply me with TERMINATE.\"\n", " else:\n", - " self.logs[\"is_correct\"] = 0\n", + " self.is_correct = 0\n", " return False, None\n", " else:\n", " return False, None\n", "\n", " def _reset(self):\n", - " self._valid_q_count = 0\n", - " self._total_q_count = 0\n", - " self._accum_invalid_q_per_step = 0\n", - " self._previous_code = \"\"\n", - " self.last_reply = None\n", - "\n", + " super()._reset()\n", + " self.max_function_call_trial = 3\n", + " self.is_correct = None\n", " self.query = None\n", - " self.answer = None\n", - " self.logs = {}" + " self.answer = None" ] }, { @@ -595,12 +255,12 @@ "\n", "MATAH dataset contains 12,500 challenging competition mathematics problems. Each problem in MATH has a full step-by-step solution which can be used to teach models to generate answer derivations and explanations. \n", "\n", - "We strctly follow the train/test splits of [Craft](https://github.com/lifan-yuan/CRAFT). Please specific your own path to the dataset. Here we sample the first 10 algebra problems as examples. " + "We strctly follow the [train](https://github.com/lifan-yuan/CRAFT/blob/main/tab_and_math/MATH/dataset/train/algebra.jsonl)/[test](https://github.com/lifan-yuan/CRAFT/blob/main/tab_and_math/MATH/dataset/algebra.jsonl) splits of [Craft](https://github.com/lifan-yuan/CRAFT). Please specific your own path to the dataset. Here we sample the first 10 algebra problems as examples. " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -625,19 +285,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ - "config_list = autogen.config_list_from_json(\n", - " \"OAI_CONFIG_LIST\",\n", - ")\n", - "mathproxyagent = MathUserProxyAgent(\n", - " name=\"mathproxyagent\",\n", - " human_input_mode=\"NEVER\",\n", - " code_execution_config={\"work_dir\": \"_output\", \"use_docker\": False},\n", - " is_termination_msg=is_termination_msg_mathchat,\n", - ")\n", + "config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n", + "\n", "assistant = autogen.AssistantAgent(\n", " name=\"assistant\",\n", " system_message=\"You are a helpful assistant.\",\n", @@ -646,6 +299,11 @@ " \"seed\": 42,\n", " \"config_list\": config_list,\n", " },\n", + ")\n", + "user_proxy = MathUserProxyAgent(\n", + " name=\"mathproxyagent\",\n", + " human_input_mode=\"NEVER\",\n", + " code_execution_config={\"work_dir\": \"_output\", \"use_docker\": False},\n", ")" ] }, @@ -666,32 +324,22 @@ "metadata": {}, "outputs": [], "source": [ - "history_list, statistic_list = {}, {}\n", - "for index, query in enumerate(test_data):\n", - " single_statistic, chat_history = mathproxyagent.initiate_chat(\n", - " recipient=assistant, answer=query[\"answer\"], query=[\"question\"], problem=query[\"question\"]\n", - " )\n", - " history_list[\"conversation: {index}\".format(index=index + 1)] = chat_history\n", - " statistic_list[\"conversation: {index}\".format(index=index + 1)] = single_statistic\n", - "\n", "sum = 0\n", - "for key, value in statistic_list.items():\n", - " if \"is_correct\" not in value.keys():\n", - " statistic_list[key][\"is_correct\"] = 0\n", - "for key, value in statistic_list.items():\n", - " sum += value[\"is_correct\"]\n", - "\n", - "success_rate_without_agent_training = sum / len(statistic_list.keys())" + "for index, query in enumerate(test_data):\n", + " is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query[\"answer\"], problem=query[\"question\"])\n", + " print(is_correct)\n", + " sum += is_correct\n", + "success_rate_without_agent_training = sum / 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Agent optimizing \n", - "\n", - "Then, use the AgentOptimizer to iteratively optimize the agents by optimizing the function calls according to the historical conversations and performance. \n", + "# Agent Training \n", "\n", + "Then, we use the AgentOptimizer to iteratively optimize the agents by optimizing the function calls according to the historical conversations and performance.\n", + "The AgentOptimizer yields register_for_llm and register_for_executor at each iteration, which are subsequently utilized to update the assistant and user_proxy agents, respectively. \n", "Here we optimize these two agents for ten epochs. " ] }, @@ -702,42 +350,20 @@ "outputs": [], "source": [ "EPOCH = 10\n", - "OAI_config = {\n", - " \"AZURE_OPENAI_API_KEY\": \"gpt-4-1106-preview\",\n", - " \"api_version\": \"2024-02-15-preview\",\n", - " \"azure_endpoint\": \"your_azure_endpoint\",\n", - " \"model\": \"your_model\",\n", - "}\n", - "agent_optimizer = AgentOptimizer(OAI_config=OAI_config)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "history_list, statistic_list, function_json, agent_list = [], [], [], []\n", - "for epoch in range(EPOCH):\n", - " if len(history_list) != 0:\n", - " actions, function_json = agent_optimizer.step(history_list, statistic_list, function_json)\n", - " for action in actions:\n", - " agent_optimizer.update_function_call(action, mathproxyagent=mathproxyagent, assistant=assistant)\n", - " history_list, statistic_list = {}, {}\n", + "optimizer_model = \"gpt-4-1106-preview\"\n", + "optimizer = AgentOptimizer(\n", + " max_actions_per_step=3, config_file_or_env=\"OAI_CONFIG_LIST\", optimizer_model=optimizer_model\n", + ")\n", + "for i in range(EPOCH):\n", " for index, query in enumerate(train_data):\n", - " single_statistic, chat_history = mathproxyagent.initiate_chat(\n", - " recipient=assistant, answer=query[\"answer\"], query=[\"question\"], problem=query[\"question\"]\n", - " )\n", - " history_list[\"conversation: {index}\".format(index=index + 1)] = chat_history\n", - " statistic_list[\"conversation: {index}\".format(index=index + 1)] = single_statistic\n", - "\n", - " sum = 0\n", - " for key, value in statistic_list.items():\n", - " if \"is_correct\" not in value.keys():\n", - " statistic_list[key][\"is_correct\"] = 0\n", - " for key, value in statistic_list.items():\n", - " sum += value[\"is_correct\"]\n", - " print(\"Train_Epoch_{epoch_num}_Success_Rate: {average}%\".format(epoch_num=epoch, average=sum / len(statistic_list)))" + " is_correct = user_proxy.initiate_chat(assistant, answer=query[\"answer\"], problem=query[\"question\"])\n", + " history = assistant.chat_messages_for_summary(user_proxy)\n", + " optimizer.record_one_conversation(history, is_satisfied=is_correct)\n", + " register_for_llm, register_for_exector = optimizer.step()\n", + " for item in register_for_llm:\n", + " assistant.update_function_signature(**item)\n", + " if len(register_for_exector.keys()) > 0:\n", + " user_proxy.register_function(function_map=register_for_exector)" ] }, { @@ -758,27 +384,16 @@ "metadata": {}, "outputs": [], "source": [ - "history_list, statistic_list = {}, {}\n", - "for index, query in enumerate(test_data):\n", - " single_statistic, chat_history = mathproxyagent.initiate_chat(\n", - " recipient=assistant, answer=query[\"answer\"], query=[\"question\"], problem=query[\"question\"]\n", - " )\n", - " history_list[\"conversation: {index}\".format(index=index + 1)] = chat_history\n", - " statistic_list[\"conversation: {index}\".format(index=index + 1)] = single_statistic\n", - "\n", "sum = 0\n", - "for key, value in statistic_list.items():\n", - " if \"is_correct\" not in value.keys():\n", - " statistic_list[key][\"is_correct\"] = 0\n", - "for key, value in statistic_list.items():\n", - " sum += value[\"is_correct\"]\n", - "\n", - "success_rate_with_agent_training = sum / len(statistic_list.keys())" + "for index, query in enumerate(test_data):\n", + " is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query[\"answer\"], problem=query[\"question\"])\n", + " sum += is_correct\n", + "success_rate_with_agent_training = sum / 10" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -786,9 +401,12 @@ "output_type": "stream", "text": [ "------------------------------------------------Functions learned------------------------------------------------\n", - "{'name': 'evaluate_expression', 'description': 'Evaluate arithmetic or mathematical expressions provided as strings.', 'arguments': {'expression': {'type': 'string', 'description': 'The mathematical expression to evaluate.'}}, 'packages': 'sympy', 'code': 'from sympy import sympify, SympifyError\\n\\ndef evaluate_expression(expression):\\n try:\\n result = sympify(expression)\\n if result.is_number:\\n result = float(result)\\n else:\\n result = str(result)\\n return result\\n except SympifyError as e:\\n return str(e)'}\n", - "{'name': 'calculate_compound_interest_principal', 'description': 'Calculate the principal amount needed to achieve a certain future value with quarterly compound interest.', 'arguments': {'future_value': {'type': 'number', 'description': 'The total amount of money desired in the future.'}, 'annual_interest_rate': {'type': 'number', 'description': 'The annual interest rate in decimal form.'}, 'compounding_periods': {'type': 'integer', 'description': 'The number of times interest is compounded per year.'}, 'years': {'type': 'number', 'description': 'The time in years the money is invested for.'}}, 'packages': 'sympy', 'code': \"from sympy import symbols, solve, N\\n\\nfuture_value, annual_interest_rate, compounding_periods, years = symbols('future_value annual_interest_rate compounding_periods years')\\nP = symbols('P')\\nequation = Eq(future_value, P * (1 + annual_interest_rate / compounding_periods) ** (compounding_periods * years))\\nprincipal = solve(equation, P)[0]\\nprincipal = N(principal, chop=True).evalf()\"}\n", - "{'name': 'solve_linear_system', 'description': 'Solve a system of linear equations represented as coefficients and variables.', 'arguments': {'equations': {'type': 'array', 'items': {'type': 'array', 'items': {'type': 'number'}}}, 'variables': {'type': 'array', 'items': {'type': 'string'}}}, 'packages': 'sympy', 'code': 'from sympy import Matrix, symbols, linsolve\\n\\ndef solve_linear_system(equations, variables):\\n if not equations or not all(len(eq) == len(variables) + 1 for eq in equations):\\n return \"Error: Equations list is empty or not all equations have the correct length.\"\\n if len(variables) != len(set(variables)):\\n return \"Error: Duplicate symbols in the \\'variables\\' list.\"\\n try:\\n sym_vars = symbols(\\' \\'.join(variables))\\n matrix = Matrix(equations)\\n system = (matrix, sym_vars)\\n solution = linsolve(system)\\n solution_list = list(solution)\\n if not solution_list:\\n return \"Error: No solution exists for the given system of equations.\"\\n result = solution_list[0] if solution_list else []\\n return dict(zip(variables, result))\\n except Exception as e:\\n return f\"Error: {str(e)}\"'}\n", + "evaluate_expression: Evaluate arithmetic or mathematical expressions provided as strings.\n", + "\n", + "calculate_compound_interest_principal: Calculate the principal amount needed to achieve a certain future value with quarterly compound interest.\n", + "\n", + "solve_linear_system: Solve a system of linear equations represented as coefficients and variables.\n", + "\n", "------------------------------------------------Summary------------------------------------------------\n", "\n", "success_rate_without_agent_training: 60.0%\n", @@ -802,8 +420,8 @@ "print(\n", " \"------------------------------------------------Functions learned------------------------------------------------\"\n", ")\n", - "for function in function_json:\n", - " print(function)\n", + "for func in assistant.llm_config[\"functions\"]:\n", + " print(func[\"name\"] + \": \" + func[\"description\"] + \"\\n\")\n", "print(\"------------------------------------------------Summary------------------------------------------------\\n\")\n", "print(\"success_rate_without_agent_training: {average}%\\n\".format(average=success_rate_without_agent_training * 100))\n", "print(\"success_rate_with_agent_training: {average}%\\n\".format(average=success_rate_with_agent_training * 100))" @@ -826,7 +444,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/test/agentchat/contrib/test_agent_optimizer.py b/test/agentchat/contrib/test_agent_optimizer.py new file mode 100644 index 000000000000..fe0ba2dc289b --- /dev/null +++ b/test/agentchat/contrib/test_agent_optimizer.py @@ -0,0 +1,110 @@ +import pytest +import os +from conftest import skip_openai as skip + +import autogen +from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC +from autogen import AssistantAgent, UserProxyAgent, config_list_from_json + +from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer + +here = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.mark.skipif( + skip, + reason="requested to skip", +) +def test_record_conversation(): + problem = "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}" + + config_list = autogen.config_list_from_json( + OAI_CONFIG_LIST, + file_location=KEY_LOC, + ) + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + llm_config={ + "timeout": 60, + "cache_seed": 42, + "config_list": config_list, + }, + ) + user_proxy = UserProxyAgent( + name="user_proxy", + human_input_mode="NEVER", + code_execution_config={ + "work_dir": f"{here}/test_agent_scripts", + "use_docker": "python:3", + "timeout": 60, + }, + max_consecutive_auto_reply=3, + ) + + user_proxy.initiate_chat(assistant, message=problem) + optimizer = AgentOptimizer(max_actions_per_step=3, config_file_or_env=OAI_CONFIG_LIST) + optimizer.record_one_conversation(assistant.chat_messages_for_summary(user_proxy), is_satisfied=True) + + assert len(optimizer._trial_conversations_history) == 1 + assert len(optimizer._trial_conversations_performance) == 1 + assert optimizer._trial_conversations_performance[0]["Conversation 0"] == 1 + + optimizer.reset_optimizer() + assert len(optimizer._trial_conversations_history) == 0 + assert len(optimizer._trial_conversations_performance) == 0 + + +@pytest.mark.skipif( + skip, + reason="requested to skip", +) +def test_step(): + problem = "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}" + + config_list = autogen.config_list_from_json( + OAI_CONFIG_LIST, + file_location=KEY_LOC, + ) + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + llm_config={ + "timeout": 60, + "cache_seed": 42, + "config_list": config_list, + }, + ) + user_proxy = UserProxyAgent( + name="user_proxy", + human_input_mode="NEVER", + code_execution_config={ + "work_dir": f"{here}/test_agent_scripts", + "use_docker": "python:3", + "timeout": 60, + }, + max_consecutive_auto_reply=3, + ) + + optimizer = AgentOptimizer(max_actions_per_step=3, config_file_or_env=OAI_CONFIG_LIST) + user_proxy.initiate_chat(assistant, message=problem) + optimizer.record_one_conversation(assistant.chat_messages_for_summary(user_proxy), is_satisfied=True) + + register_for_llm, register_for_exector = optimizer.step() + + print("-------------------------------------") + print("register_for_llm:") + print(register_for_llm) + print("register_for_exector") + print(register_for_exector) + + for item in register_for_llm: + assistant.update_function_signature(**item) + if len(register_for_exector.keys()) > 0: + user_proxy.register_function(function_map=register_for_exector) + + print("-------------------------------------") + print("Updated assistant.llm_config:") + print(assistant.llm_config) + print("Updated user_proxy._function_map:") + print(user_proxy._function_map) diff --git a/website/blog/2023-12-23-AgentOptimizer/index.mdx b/website/blog/2023-12-23-AgentOptimizer/index.mdx index b59046d55c40..23e348fa30cb 100644 --- a/website/blog/2023-12-23-AgentOptimizer/index.mdx +++ b/website/blog/2023-12-23-AgentOptimizer/index.mdx @@ -11,57 +11,83 @@ tags: [LLM, research] **TL;DR:** Introducing **AgentOptimizer**, a new class for training LLM agents in the era of LLMs as a service. -**AgentOptimizer** is able to prompt autogen agents to iteratively optimize its function/skills according to the historical conversation and performance. +**AgentOptimizer** is able to prompt LLMs to iteratively optimize function/skills of AutoGen agents according to the historical conversation and performance. Checkout one implementation for **AgentOptimizer** on [MATH](https://github.com/hendrycks/math) dataset [here](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentoptimizer.ipynb). -Paper is coming soon! +More information could be found in the [paper](https://arxiv.org/abs/2402.11359). ## Introduction -In traditional ML pipeline, we train a model by updating its weights according to the loss on the training set, while in the era of LLM agents, how should we train an agent? -Here, we take an initial step towards the agent training. -Inspired by the [function calling](https://platform.openai.com/docs/guides/function-calling) capabilities provided by OpenAI, +In the traditional ML pipeline, we train a model by updating its weights according to the loss on the training set, while in the era of LLM agents, how should we train an agent? +Here, we take an initial step towards the agent training. Inspired by the [function calling](https://platform.openai.com/docs/guides/function-calling) capabilities provided by OpenAI, we draw an analogy between model weights and agent functions/skills, and update an agent’s functions/skills based on its historical performance on a training set. -Specifically, we propose to use the function calling capabilities to formulate the actions that optimize the agents’ functions as a set of function calls, -to support iteratively **adding, revising, and removing** existing functions. -As an agentic way of training an agent, our approach helps enhance the agents’ abilities without requiring access to the LLMs weights. +Specifically, we propose to use the function calling capabilities to formulate the actions that optimize the agents’ functions as a set of function calls, to support iteratively **adding, revising, and removing** existing functions. +We also include two strategies, roll-back, and early-stop, to streamline the training process to overcome the performance-decreasing problem when training. +As an agentic way of training an agent, our approach helps enhance the agents’ abilities without requiring access to the LLM's weights. ## AgentOptimizer **AgentOptimizer** is a class designed to optimize the agents by improving their function calls. -It contains two core methods: +It contains three main methods: -1. `step()`: `step()` takes three inputs, including the previous conversation history (history), the statistical information of solving previous problems (statistic), and the current functions (current_functions). +1. `record_one_conversation`: + +This method records the conversation history and performance of the agents in solving one problem. +It includes two inputs: conversation_history (List[Dict]) and is_satisfied (bool). +conversation_history is a list of dictionaries which could be got from chat_messages_for_summary in the [AgentChat](https://microsoft.github.io/autogen/docs/reference/agentchat/agentchat/) class. +is_satisfied is a bool value that represents whether the user is satisfied with the solution. If it is none, the user will be asked to input the satisfaction. + +Example: + +```python +optimizer = AgentOptimizer(max_actions_per_step=3, config_file_or_env="OAI_CONFIG_LIST") +# ------------ code to solve a problem ------------ +# ...... +# ------------------------------------------------- +history = assistant.chat_messages_for_summary(UserProxy) +optimizer.record_one_conversation(history, is_satisfied=result) +``` + + +2. `step()`: + +`step()` is the core method of AgentOptimizer. +At each optimization iteration, it will return two fields register_for_llm and register_for_executor, which are subsequently utilized to update the assistant and UserProxy agents, respectively. ```python -actions, updated_functions = AgentOptimizer.step(history, statistic, current_functions) +register_for_llm, register_for_exector = optimizer.step() +for item in register_for_llm: + assistant.update_function_signature(**item) +if len(register_for_exector.keys()) > 0: + user_proxy.register_function(function_map=register_for_exector) ``` -It has two outputs `actions` and `updated_functions`. `actions` is a series of actions to manipulate the current functions. And `updated_functions` is the updated functions after the actions are applied (including code implementation). +3. `reset_optimizer`: -2. `update_function_call()`: -This method takes the agents and actions as input. It updates the functions registered in these agents according to the actions from `step()`. -For AssistantAgent, it first uses [update_function_signature](https://microsoft.github.io/autogen/docs/reference/agentchat/conversable_agent/#update_function_signature) to update the function signatures. -Then, it updates the functions in the MathUserproxyAgent with the corresponding code implementation gained from `step()`. +This method will reset the optimizer to the initial state, which is useful when you want to train the agent from scratch. -Sometimes, the function signatures (JSON schema) returned by the `step()` may not be valid, and the generated code may also face syntax errors. -**AgentOptimizer** includes mechanisms to check the (1) validity of the function signatures and (2) code implementation before updating the functions. -Moreover, it also includes mechanisms to check whether each action is feasible, such as avoiding the removal of a function that is not in the current functions due to hallucination. +**AgentOptimizer** includes mechanisms to check the (1) validity of the function and (2) code implementation before returning the register_for_llm, register_for_exector. +Moreover, it also includes mechanisms to check whether each update is feasible, such as avoiding the removal of a function that is not in the current functions due to hallucination. ## Pseudocode for the optimization process The optimization process is as follows: ```python -for - in range(EPOCH): - history, statistic, current_functions = solve_problems(train_problems) - actions, updated_functions = AgentOptimizer.step(history, statistic, current_functions) - AgentOptimizer.update_function_call(actions) +optimizer = AgentOptimizer(max_actions_per_step=3, config_file_or_env="OAI_CONFIG_LIST") +for i in range(EPOCH): + is_correct = user_proxy.initiate_chat(assistant, message = problem) + history = assistant.chat_messages_for_summary(user_proxy) + optimizer.record_one_conversation(history, is_satisfied=is_correct) + register_for_llm, register_for_exector = optimizer.step() + for item in register_for_llm: + assistant.update_function_signature(**item) + if len(register_for_exector.keys()) > 0: + user_proxy.register_function(function_map=register_for_exector) ``` Given a prepared training dataset, the agents iteratively solve problems from the training set to obtain conversation history and statistical information. -The functions are then improved using AgentOptimizer. -Each iteration can be regarded as one training step analogous to traditional machine learning, with the optimization elements being the functions that agents have. +The functions are then improved using AgentOptimizer. Each iteration can be regarded as one training step analogous to traditional machine learning, with the optimization elements being the functions that agents have. After EPOCH iterations, the agents are expected to obtain better functions that may be used in future tasks @@ -71,7 +97,7 @@ To obtain stable and structured function signatures and code implementations fro we leverage the function calling capabilities provided by OpenAI to formulate the actions that manipulate the functions as a set of function calls. Specifically, we introduce three function calls to manipulate the current functions at each step: `add_function`, `remove_function`, and `revise_function`. These calls add, remove, and revise functions in the existing function list, respectively. -This practice could fully leverages the function calling capabilities of GPT-4 and outputs structured functions with more stable signatures and code implementation. +This practice could fully leverage the function calling capabilities of GPT-4 and output structured functions with more stable signatures and code implementation. Below is the JSON schema of these function calls: 1. `add_function`: Add one new function that may be used in the future tasks. @@ -84,30 +110,24 @@ ADD_FUNC = { "parameters": { "type": "object", "properties": { - "name": { - "type": "string", - "description": "The name of the function in the code implementation." - }, - "description": { - "type": "string", - "description": "A short description of the function." - }, + "name": {"type": "string", "description": "The name of the function in the code implementation."}, + "description": {"type": "string", "description": "A short description of the function."}, "arguments": { "type": "string", - "description": "JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { \"url\": { \"type\": \"string\", \"description\": \"The URL\", }}. Please avoid the error 'array schema missing items' when using array type." + "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.', }, "packages": { "type": "string", - "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list." + "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.", }, "code": { "type": "string", - "description": "The implementation in Python. Do not include the function declaration." - } + "description": "The implementation in Python. Do not include the function declaration.", + }, }, - "required": ["name", "description", "arguments", "packages", "code"] - } - } + "required": ["name", "description", "arguments", "packages", "code"], + }, + }, } ``` @@ -121,30 +141,24 @@ REVISE_FUNC = { "parameters": { "type": "object", "properties": { - "name": { - "type": "string", - "description": "The name of the function in the code implementation." - }, - "description": { - "type": "string", - "description": "A short description of the function." - }, + "name": {"type": "string", "description": "The name of the function in the code implementation."}, + "description": {"type": "string", "description": "A short description of the function."}, "arguments": { "type": "string", - "description": "JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { \"url\": { \"type\": \"string\", \"description\": \"The URL\", }}. Please avoid the error 'array schema missing items' when using array type." + "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.', }, "packages": { "type": "string", - "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list." + "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.", }, "code": { "type": "string", - "description": "The implementation in Python. Do not include the function declaration." - } + "description": "The implementation in Python. Do not include the function declaration.", + }, }, - "required": ["name", "description", "arguments", "packages", "code"] - } - } + "required": ["name", "description", "arguments", "packages", "code"], + }, + }, } ``` @@ -158,22 +172,16 @@ REMOVE_FUNC = { "parameters": { "type": "object", "properties": { - "name": { - "type": "string", - "description": "The name of the function in the code implementation." - } + "name": {"type": "string", "description": "The name of the function in the code implementation."} }, - "required": ["name"] - } - } + "required": ["name"], + }, + }, } ``` ## Limitation & Future work -1. Unlike gradient descent in traditional machine learning training processes, each optimization step does not necessarily lead to better performance on the training set. -When the training epoch is small, the agent’s performance may even decrease. One urgent task is to design a better mechanism to guide the optimization process. -2. The [Current implementation](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentoptimizer.ipynb) of AgentOptimizer is mainly for illustration purpose and is just a proof of concept. -It is not formally integrated into the autogen with a general interface like optimizing any kinds of agents in any tasks. -Currently, it only supports optimizing the multi-agent system in solving problems from [MATH](https://github.com/hendrycks/math) dataset. We will integrate it into autogen with more general interface in the future. +1. Currently, it only supports optimizing the one typical user_proxy and assistant agents pair. We will make this feature more general to support other agent types in future work. +2. The current implementation of AgentOptimizer is effective solely on the OpenAI GPT-4 model. Extending this feature/concept to other LLMs is the next step.