Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate openai assistant related config items from llm_config #1964

Closed
wants to merge 12 commits into from
69 changes: 51 additions & 18 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
name="GPT Assistant",
instructions: Optional[str] = None,
llm_config: Optional[Union[Dict, bool]] = None,
assistant_config: Optional[Dict] = None,
overwrite_instructions: bool = False,
overwrite_tools: bool = False,
**kwargs,
Expand All @@ -43,8 +44,9 @@ def __init__(
AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the
system message will be set to the existing assistant instructions.
llm_config (dict or False): llm inference configuration.
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
assistant_config
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
Expand All @@ -57,23 +59,19 @@ def __init__(
"""

self._verbose = kwargs.pop("verbose", False)
openai_client_cfg, openai_assistant_cfg = self._process_assistant_config(llm_config, assistant_config)

super().__init__(
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=llm_config, **kwargs
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
)

if llm_config is False:
raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")
# Use AutoGen OpenAIWrapper to create a client
openai_client_cfg = copy.deepcopy(llm_config)
# Use the class variable
model_name = GPTAssistantAgent.DEFAULT_MODEL_NAME

# GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
# See: https://github.com/microsoft/autogen/pull/1721
model_name = self.DEFAULT_MODEL_NAME
if openai_client_cfg.get("config_list") is not None and len(openai_client_cfg["config_list"]) > 0:
model_name = openai_client_cfg["config_list"][0].pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME)
model_name = openai_client_cfg["config_list"][0].pop("model", self.DEFAULT_MODEL_NAME)
else:
model_name = openai_client_cfg.pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME)
model_name = openai_client_cfg.pop("model", self.DEFAULT_MODEL_NAME)

logger.warning("OpenAI client config of GPTAssistantAgent(%s) - model: %s", name, model_name)

Expand All @@ -82,14 +80,17 @@ def __init__(
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")

self._openai_client = oai_wrapper._clients[0]._oai_client
openai_assistant_id = llm_config.get("assistant_id", None)
openai_assistant_id = openai_assistant_cfg.get("assistant_id", None)
if openai_assistant_id is None:
# try to find assistant by name first
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)
if len(candidate_assistants) > 0:
# Filter out candidates with the same name but different instructions, file IDs, and function names.
candidate_assistants = self.find_matching_assistant(
candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", [])
candidate_assistants,
instructions,
openai_assistant_cfg.get("tools", []),
openai_assistant_cfg.get("file_ids", []),
)

if len(candidate_assistants) == 0:
Expand All @@ -103,9 +104,9 @@ def __init__(
self._openai_assistant = self._openai_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
tools=openai_assistant_cfg.get("tools", []),
model=model_name,
file_ids=llm_config.get("file_ids", []),
file_ids=openai_assistant_cfg.get("file_ids", []),
)
else:
logger.warning(
Expand Down Expand Up @@ -135,8 +136,8 @@ def __init__(
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
)

# Check if tools are specified in llm_config
specified_tools = llm_config.get("tools", None)
# Check if tools are specified in assistant_config
specified_tools = openai_assistant_cfg.get("tools", None)

if specified_tools is None:
# Check if the current assistant has tools defined
Expand All @@ -155,7 +156,7 @@ def __init__(
)
self._openai_assistant = self._openai_client.beta.assistants.update(
assistant_id=openai_assistant_id,
tools=llm_config.get("tools", []),
tools=openai_assistant_cfg.get("tools", []),
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
Expand Down Expand Up @@ -414,6 +415,10 @@ def assistant_id(self):
def openai_client(self):
return self._openai_client

@property
def openai_assistant(self):
    """The underlying OpenAI Assistant object managed by this agent (read-only)."""
    return self._openai_assistant

def get_assistant_instructions(self):
    """Return the instructions of the cached OpenAI assistant object.

    Note: this reads the locally cached assistant object, not a fresh
    fetch from the OpenAI Assistants API — presumably populated at
    construction time (TODO confirm against __init__).
    """
    return self._openai_assistant.instructions
Expand Down Expand Up @@ -472,3 +477,31 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
matching_assistants.append(assistant)

return matching_assistants

def _process_assistant_config(self, llm_config, assistant_config):
"""
Process the llm_config and assistant_config to extract the model name and assistant related configurations.
"""

if llm_config is False:
raise ValueError("llm_config=False is not supported for GPTAssistantAgent.")

if llm_config is None:
openai_client_cfg = {}
else:
openai_client_cfg = copy.deepcopy(llm_config)

if assistant_config is None:
openai_assistant_cfg = {}
else:
openai_assistant_cfg = copy.deepcopy(assistant_config)

# Move the assistant related configurations to assistant_config
# It's important to keep forward compatibility
assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
for item in assistant_config_items:
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
openai_assistant_cfg[item] = openai_client_cfg[item]
openai_client_cfg.pop(item, None)

return openai_client_cfg, openai_assistant_cfg
105 changes: 57 additions & 48 deletions notebook/agentchat_oai_assistant_function_call.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
"import logging\n",
"import os\n",
"\n",
"import requests\n",
"\n",
"from autogen import UserProxyAgent, config_list_from_json\n",
"from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent\n",
"\n",
Expand Down Expand Up @@ -79,32 +77,24 @@
"\n",
"def get_ossinsight(question):\n",
" \"\"\"\n",
" Retrieve the top 10 developers with the most followers on GitHub.\n",
" [Mock] Retrieve the top 10 developers with the most followers on GitHub.\n",
" \"\"\"\n",
" url = \"https://api.ossinsight.io/explorer/answer\"\n",
" headers = {\"Content-Type\": \"application/json\"}\n",
" data = {\"question\": question, \"ignoreCache\": True}\n",
"\n",
" response = requests.post(url, headers=headers, json=data)\n",
" if response.status_code == 200:\n",
" answer = response.json()\n",
" else:\n",
" return f\"Request to {url} failed with status code: {response.status_code}\"\n",
"\n",
" report_components = []\n",
" report_components.append(f\"Question: {answer['question']['title']}\")\n",
" if answer[\"query\"][\"sql\"] != \"\":\n",
" report_components.append(f\"querySQL: {answer['query']['sql']}\")\n",
"\n",
" if answer.get(\"result\", None) is None or len(answer[\"result\"][\"rows\"]) == 0:\n",
" result = \"Result: N/A\"\n",
" else:\n",
" result = \"Result:\\n \" + \"\\n \".join([str(row) for row in answer[\"result\"][\"rows\"]])\n",
" report_components.append(result)\n",
"\n",
" if answer.get(\"error\", None) is not None:\n",
" report_components.append(f\"Error: {answer['error']}\")\n",
" return \"\\n\\n\".join(report_components) + \"\\n\\n\""
" report_components = [\n",
" f\"Question: {question}\",\n",
" \"SQL: SELECT `login` AS `user_login`, `followers` AS `followers` FROM `github_users` ORDER BY `followers` DESC LIMIT 10\",\n",
" \"\"\"Results:\n",
" {'followers': 166730, 'user_login': 'torvalds'}\n",
" {'followers': 86239, 'user_login': 'yyx990803'}\n",
" {'followers': 77611, 'user_login': 'gaearon'}\n",
" {'followers': 72668, 'user_login': 'ruanyf'}\n",
" {'followers': 65415, 'user_login': 'JakeWharton'}\n",
" {'followers': 60972, 'user_login': 'peng-zhihui'}\n",
" {'followers': 58172, 'user_login': 'bradtraversy'}\n",
" {'followers': 52143, 'user_login': 'gustavoguanabara'}\n",
" {'followers': 51542, 'user_login': 'sindresorhus'}\n",
" {'followers': 49621, 'user_login': 'tj'}\"\"\",\n",
" ]\n",
" return \"\\n\" + \"\\n\\n\".join(report_components) + \"\\n\""
]
},
{
Expand All @@ -120,12 +110,24 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"OpenAI client config of GPTAssistantAgent(OSS Analyst) - model: gpt-4-turbo-preview\n",
"GPT Assistant only supports one OpenAI client. Using the first client in the list.\n",
"No matching assistant found, creating a new assistant\n"
]
}
],
"source": [
"assistant_id = os.environ.get(\"ASSISTANT_ID\", None)\n",
"config_list = config_list_from_json(\"OAI_CONFIG_LIST\")\n",
"llm_config = {\n",
" \"config_list\": config_list,\n",
"}\n",
"assistant_config = {\n",
" \"assistant_id\": assistant_id,\n",
" \"tools\": [\n",
" {\n",
Expand All @@ -143,6 +145,7 @@
" \"Please carefully read the context of the conversation to identify the current analysis question or problem that needs addressing.\"\n",
" ),\n",
" llm_config=llm_config,\n",
" assistant_config=assistant_config,\n",
" verbose=True,\n",
")\n",
"oss_analyst.register_function(\n",
Expand Down Expand Up @@ -178,13 +181,14 @@
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION ossinsight_data_api...\u001b[0m\n",
"\u001b[35m\n",
"Input arguments: {'question': 'Who are the top 10 developers with the most followers on GitHub?'}\n",
"Input arguments: {'question': 'Top 10 developers with the most followers'}\n",
"Output:\n",
"Question: Who are the top 10 developers with the most followers on GitHub?\n",
"\n",
"querySQL: SELECT `login` AS `user_login`, `followers` AS `followers` FROM `github_users` ORDER BY `followers` DESC LIMIT 10\n",
"Question: Top 10 developers with the most followers\n",
"\n",
"Result:\n",
"SQL: SELECT `login` AS `user_login`, `followers` AS `followers` FROM `github_users` ORDER BY `followers` DESC LIMIT 10\n",
"\n",
"Results:\n",
" {'followers': 166730, 'user_login': 'torvalds'}\n",
" {'followers': 86239, 'user_login': 'yyx990803'}\n",
" {'followers': 77611, 'user_login': 'gaearon'}\n",
Expand All @@ -195,24 +199,21 @@
" {'followers': 52143, 'user_login': 'gustavoguanabara'}\n",
" {'followers': 51542, 'user_login': 'sindresorhus'}\n",
" {'followers': 49621, 'user_login': 'tj'}\n",
"\n",
"\u001b[0m\n",
"\u001b[33mOSS Analyst\u001b[0m (to user_proxy):\n",
"\n",
"The top 10 developers with the most followers on GitHub are as follows:\n",
"The top 10 developers with the most followers on GitHub are:\n",
"\n",
"1. `torvalds` with 166,730 followers\n",
"2. `yyx990803` with 86,239 followers\n",
"3. `gaearon` with 77,611 followers\n",
"4. `ruanyf` with 72,668 followers\n",
"5. `JakeWharton` with 65,415 followers\n",
"6. `peng-zhihui` with 60,972 followers\n",
"7. `bradtraversy` with 58,172 followers\n",
"8. `gustavoguanabara` with 52,143 followers\n",
"9. `sindresorhus` with 51,542 followers\n",
"10. `tj` with 49,621 followers\n",
"\n",
"These figures indicate the number of followers these developers had at the time of the analysis.\n",
"1. **Linus Torvalds** (`torvalds`) with 166,730 followers\n",
"2. **Evan You** (`yyx990803`) with 86,239 followers\n",
"3. **Dan Abramov** (`gaearon`) with 77,611 followers\n",
"4. **Ruan YiFeng** (`ruanyf`) with 72,668 followers\n",
"5. **Jake Wharton** (`JakeWharton`) with 65,415 followers\n",
"6. **Peng Zhihui** (`peng-zhihui`) with 60,972 followers\n",
"7. **Brad Traversy** (`bradtraversy`) with 58,172 followers\n",
"8. **Gustavo Guanabara** (`gustavoguanabara`) with 52,143 followers\n",
"9. **Sindre Sorhus** (`sindresorhus`) with 51,542 followers\n",
"10. **TJ Holowaychuk** (`tj`) with 49,621 followers\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
Expand All @@ -223,11 +224,18 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mOSS Analyst\u001b[0m (to user_proxy):\n",
"\n",
"It seems you haven't entered a question or a request. Could you please provide more details or specify how I can assist you further?\n",
"It looks like there is no question or prompt for me to respond to. Could you please provide more details or ask a question that you would like assistance with?\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Permanently deleting assistant...\n"
]
}
],
"source": [
Expand All @@ -242,7 +250,8 @@
" max_consecutive_auto_reply=1,\n",
")\n",
"\n",
"user_proxy.initiate_chat(oss_analyst, message=\"Top 10 developers with the most followers\")"
"user_proxy.initiate_chat(oss_analyst, message=\"Top 10 developers with the most followers\")\n",
"oss_analyst.delete_assistant()"
]
}
],
Expand Down
Loading
Loading