Skip to content

Commit

Permalink
Separate openai assistant related config items from llm_config (micro…
Browse files Browse the repository at this point in the history
…soft#2037)

* add assistant config

* add test

* change notebook to use assistant config

* use assistant config in testing

* code refinement

---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
IANTHEREAL and ekzhu authored Mar 16, 2024
1 parent d4582ad commit edcb635
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 183 deletions.
69 changes: 51 additions & 18 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
name="GPT Assistant",
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
assistant_config: Optional[Dict] = None,
overwrite_instructions: bool = False,
overwrite_tools: bool = False,
**kwargs,
Expand All @@ -43,8 +44,9 @@ def __init__(
AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the
system message will be set to the existing assistant instructions.
llm_config (dict or False): llm inference configuration.
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
assistant_config
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
Expand All @@ -57,23 +59,19 @@ def __init__(
"""

self._verbose = kwargs.pop("verbose", False)
openai_client_cfg, openai_assistant_cfg = self._process_assistant_config(llm_config, assistant_config)

super().__init__(
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=llm_config, **kwargs
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
)

if llm_config is False:
raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")
# Use AutooGen OpenAIWrapper to create a client
openai_client_cfg = copy.deepcopy(llm_config)
# Use the class variable
model_name = GPTAssistantAgent.DEFAULT_MODEL_NAME

# GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
# See: https://github.com/microsoft/autogen/pull/1721
model_name = self.DEFAULT_MODEL_NAME
if openai_client_cfg.get("config_list") is not None and len(openai_client_cfg["config_list"]) > 0:
model_name = openai_client_cfg["config_list"][0].pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME)
model_name = openai_client_cfg["config_list"][0].pop("model", self.DEFAULT_MODEL_NAME)
else:
model_name = openai_client_cfg.pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME)
model_name = openai_client_cfg.pop("model", self.DEFAULT_MODEL_NAME)

logger.warning("OpenAI client config of GPTAssistantAgent(%s) - model: %s", name, model_name)

Expand All @@ -82,14 +80,17 @@ def __init__(
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")

self._openai_client = oai_wrapper._clients[0]._oai_client
openai_assistant_id = llm_config.get("assistant_id", None)
openai_assistant_id = openai_assistant_cfg.get("assistant_id", None)
if openai_assistant_id is None:
# try to find assistant by name first
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)
if len(candidate_assistants) > 0:
# Filter out candidates with the same name but different instructions, file IDs, and function names.
candidate_assistants = self.find_matching_assistant(
candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", [])
candidate_assistants,
instructions,
openai_assistant_cfg.get("tools", []),
openai_assistant_cfg.get("file_ids", []),
)

if len(candidate_assistants) == 0:
Expand All @@ -103,9 +104,9 @@ def __init__(
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
tools=openai_assistant_cfg.get("tools", []),
model=model_name,
file_ids=llm_config.get("file_ids", []),
file_ids=openai_assistant_cfg.get("file_ids", []),
)
else:
logger.warning(
Expand Down Expand Up @@ -135,8 +136,8 @@ def __init__(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)

# Check if tools are specified in llm_config
specified_tools = llm_config.get("tools", None)
# Check if tools are specified in assistant_config
specified_tools = openai_assistant_cfg.get("tools", None)

if specified_tools is None:
# Check if the current assistant has tools defined
Expand All @@ -155,7 +156,7 @@ def __init__(
)
self._openai_assistant = self._openai_client.beta.assistants.update(
assistant_id=openai_assistant_id,
tools=llm_config.get("tools", []),
tools=openai_assistant_cfg.get("tools", []),
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
Expand Down Expand Up @@ -414,6 +415,10 @@ def assistant_id(self):
def openai_client(self):
return self._openai_client

@property
def openai_assistant(self):
return self._openai_assistant

def get_assistant_instructions(self):
"""Return the assistant instructions from OAI assistant API"""
return self._openai_assistant.instructions
Expand Down Expand Up @@ -472,3 +477,31 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
matching_assistants.append(assistant)

return matching_assistants

def _process_assistant_config(self, llm_config, assistant_config):
"""
Process the llm_config and assistant_config to extract the model name and assistant related configurations.
"""

if llm_config is False:
raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")

if llm_config is None:
openai_client_cfg = {}
else:
openai_client_cfg = copy.deepcopy(llm_config)

if assistant_config is None:
openai_assistant_cfg = {}
else:
openai_assistant_cfg = copy.deepcopy(assistant_config)

# Move the assistant related configurations to assistant_config
# It's important to keep forward compatibility
assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
for item in assistant_config_items:
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
openai_assistant_cfg[item] = openai_client_cfg[item]
openai_client_cfg.pop(item, None)

return openai_client_cfg, openai_assistant_cfg
105 changes: 57 additions & 48 deletions notebook/agentchat_oai_assistant_function_call.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
"import logging\n",
"import os\n",
"\n",
"import requests\n",
"\n",
"from autogen import UserProxyAgent, config_list_from_json\n",
"from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent\n",
"\n",
Expand Down Expand Up @@ -79,32 +77,24 @@
"\n",
"def get_ossinsight(question):\n",
" \"\"\"\n",
" Retrieve the top 10 developers with the most followers on GitHub.\n",
" [Mock] Retrieve the top 10 developers with the most followers on GitHub.\n",
" \"\"\"\n",
" url = \"https://api.ossinsight.io/explorer/answer\"\n",
" headers = {\"Content-Type\": \"application/json\"}\n",
" data = {\"question\": question, \"ignoreCache\": True}\n",
"\n",
" response = requests.post(url, headers=headers, json=data)\n",
" if response.status_code == 200:\n",
" answer = response.json()\n",
" else:\n",
" return f\"Request to {url} failed with status code: {response.status_code}\"\n",
"\n",
" report_components = []\n",
" report_components.append(f\"Question: {answer['question']['title']}\")\n",
" if answer[\"query\"][\"sql\"] != \"\":\n",
" report_components.append(f\"querySQL: {answer['query']['sql']}\")\n",
"\n",
" if answer.get(\"result\", None) is None or len(answer[\"result\"][\"rows\"]) == 0:\n",
" result = \"Result: N/A\"\n",
" else:\n",
" result = \"Result:\\n \" + \"\\n \".join([str(row) for row in answer[\"result\"][\"rows\"]])\n",
" report_components.append(result)\n",
"\n",
" if answer.get(\"error\", None) is not None:\n",
" report_components.append(f\"Error: {answer['error']}\")\n",
" return \"\\n\\n\".join(report_components) + \"\\n\\n\""
" report_components = [\n",
" f\"Question: {question}\",\n",
" \"SQL: SELECT `login` AS `user_login`, `followers` AS `followers` FROM `github_users` ORDER BY `followers` DESC LIMIT 10\",\n",
" \"\"\"Results:\n",
" {'followers': 166730, 'user_login': 'torvalds'}\n",
" {'followers': 86239, 'user_login': 'yyx990803'}\n",
" {'followers': 77611, 'user_login': 'gaearon'}\n",
" {'followers': 72668, 'user_login': 'ruanyf'}\n",
" {'followers': 65415, 'user_login': 'JakeWharton'}\n",
" {'followers': 60972, 'user_login': 'peng-zhihui'}\n",
" {'followers': 58172, 'user_login': 'bradtraversy'}\n",
" {'followers': 52143, 'user_login': 'gustavoguanabara'}\n",
" {'followers': 51542, 'user_login': 'sindresorhus'}\n",
" {'followers': 49621, 'user_login': 'tj'}\"\"\",\n",
" ]\n",
" return \"\\n\" + \"\\n\\n\".join(report_components) + \"\\n\""
]
},
{
Expand All @@ -120,12 +110,24 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"OpenAI client config of GPTAssistantAgent(OSS Analyst) - model: gpt-4-turbo-preview\n",
"GPT Assistant only supports one OpenAI client. Using the first client in the list.\n",
"No matching assistant found, creating a new assistant\n"
]
}
],
"source": [
"assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n",
"config_list = config_list_from_json(\"OAI_CONFIG_LIST\")\n",
"llm_config = {\n",
" \"config_list\": config_list,\n",
"}\n",
"assistant_config = {\n",
" \"assistant_id\": assistant_id,\n",
" \"tools\": [\n",
" {\n",
Expand All @@ -143,6 +145,7 @@
" \"Please carefully read the context of the conversation to identify the current analysis question or problem that needs addressing.\"\n",
" ),\n",
" llm_config=llm_config,\n",
" assistant_config=assistant_config,\n",
" verbose=True,\n",
")\n",
"oss_analyst.register_function(\n",
Expand Down Expand Up @@ -178,13 +181,14 @@
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION ossinsight_data_api...\u001b[0m\n",
"\u001b[35m\n",
"Input arguments: {'question': 'Who are the top 10 developers with the most followers on GitHub?'}\n",
"Input arguments: {'question': 'Top 10 developers with the most followers'}\n",
"Output:\n",
"Question: Who are the top 10 developers with the most followers on GitHub?\n",
"\n",
"querySQL: SELECT `login` AS `user_login`, `followers` AS `followers` FROM `github_users` ORDER BY `followers` DESC LIMIT 10\n",
"Question: Top 10 developers with the most followers\n",
"\n",
"Result:\n",
"SQL: SELECT `login` AS `user_login`, `followers` AS `followers` FROM `github_users` ORDER BY `followers` DESC LIMIT 10\n",
"\n",
"Results:\n",
" {'followers': 166730, 'user_login': 'torvalds'}\n",
" {'followers': 86239, 'user_login': 'yyx990803'}\n",
" {'followers': 77611, 'user_login': 'gaearon'}\n",
Expand All @@ -195,24 +199,21 @@
" {'followers': 52143, 'user_login': 'gustavoguanabara'}\n",
" {'followers': 51542, 'user_login': 'sindresorhus'}\n",
" {'followers': 49621, 'user_login': 'tj'}\n",
"\n",
"\u001b[0m\n",
"\u001b[33mOSS Analyst\u001b[0m (to user_proxy):\n",
"\n",
"The top 10 developers with the most followers on GitHub are as follows:\n",
"The top 10 developers with the most followers on GitHub are:\n",
"\n",
"1. `torvalds` with 166,730 followers\n",
"2. `yyx990803` with 86,239 followers\n",
"3. `gaearon` with 77,611 followers\n",
"4. `ruanyf` with 72,668 followers\n",
"5. `JakeWharton` with 65,415 followers\n",
"6. `peng-zhihui` with 60,972 followers\n",
"7. `bradtraversy` with 58,172 followers\n",
"8. `gustavoguanabara` with 52,143 followers\n",
"9. `sindresorhus` with 51,542 followers\n",
"10. `tj` with 49,621 followers\n",
"\n",
"These figures indicate the number of followers these developers had at the time of the analysis.\n",
"1. **Linus Torvalds** (`torvalds`) with 166,730 followers\n",
"2. **Evan You** (`yyx990803`) with 86,239 followers\n",
"3. **Dan Abramov** (`gaearon`) with 77,611 followers\n",
"4. **Ruan YiFeng** (`ruanyf`) with 72,668 followers\n",
"5. **Jake Wharton** (`JakeWharton`) with 65,415 followers\n",
"6. **Peng Zhihui** (`peng-zhihui`) with 60,972 followers\n",
"7. **Brad Traversy** (`bradtraversy`) with 58,172 followers\n",
"8. **Gustavo Guanabara** (`gustavoguanabara`) with 52,143 followers\n",
"9. **Sindre Sorhus** (`sindresorhus`) with 51,542 followers\n",
"10. **TJ Holowaychuk** (`tj`) with 49,621 followers\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
Expand All @@ -223,11 +224,18 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mOSS Analyst\u001b[0m (to user_proxy):\n",
"\n",
"It seems you haven't entered a question or a request. Could you please provide more details or specify how I can assist you further?\n",
"It looks like there is no question or prompt for me to respond to. Could you please provide more details or ask a question that you would like assistance with?\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Permanently deleting assistant...\n"
]
}
],
"source": [
Expand All @@ -242,7 +250,8 @@
" max_consecutive_auto_reply=1,\n",
")\n",
"\n",
"user_proxy.initiate_chat(oss_analyst, message=\"Top 10 developers with the most followers\")"
"user_proxy.initiate_chat(oss_analyst, message=\"Top 10 developers with the most followers\")\n",
"oss_analyst.delete_assistant()"
]
}
],
Expand Down
Loading

0 comments on commit edcb635

Please sign in to comment.