Skip to content

Commit a41182a

Browse files
IANTHEREALekzhu
andauthored
Support openai assistant v2 API (#2466)
* adapted to openai assistant v2 api * fix comments * format code * fix ci * Update autogen/agentchat/contrib/gpt_assistant_agent.py Co-authored-by: Eric Zhu <[email protected]> --------- Co-authored-by: Eric Zhu <[email protected]>
1 parent 2daae42 commit a41182a

File tree

4 files changed

+150
-53
lines changed

4 files changed

+150
-53
lines changed

autogen/agentchat/contrib/gpt_assistant_agent.py

+31-23
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from autogen import OpenAIWrapper
1111
from autogen.agentchat.agent import Agent
1212
from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
13-
from autogen.oai.openai_utils import retrieve_assistants_by_name
13+
from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -50,7 +50,8 @@ def __init__(
5050
- check_every_ms: check thread run status interval
5151
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
5252
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
53-
- file_ids: files used by retrieval in run
53+
- file_ids: (Deprecated) files used by retrieval in run. It is Deprecated, use tool_resources instead. https://platform.openai.com/docs/assistants/migration/what-has-changed.
54+
- tool_resources: A set of resources that are used by the assistant's tools. The resources are specific to the type of tool.
5455
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
5556
overwrite_tools (bool): whether to overwrite the tools of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
5657
kwargs (dict): Additional configuration options for the agent.
@@ -90,7 +91,6 @@ def __init__(
9091
candidate_assistants,
9192
instructions,
9293
openai_assistant_cfg.get("tools", []),
93-
openai_assistant_cfg.get("file_ids", []),
9494
)
9595

9696
if len(candidate_assistants) == 0:
@@ -101,12 +101,12 @@ def __init__(
101101
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
102102
)
103103
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
104-
self._openai_assistant = self._openai_client.beta.assistants.create(
104+
self._openai_assistant = create_gpt_assistant(
105+
self._openai_client,
105106
name=name,
106107
instructions=instructions,
107-
tools=openai_assistant_cfg.get("tools", []),
108108
model=model_name,
109-
file_ids=openai_assistant_cfg.get("file_ids", []),
109+
assistant_config=openai_assistant_cfg,
110110
)
111111
else:
112112
logger.warning(
@@ -127,9 +127,12 @@ def __init__(
127127
logger.warning(
128128
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
129129
)
130-
self._openai_assistant = self._openai_client.beta.assistants.update(
130+
self._openai_assistant = update_gpt_assistant(
131+
self._openai_client,
131132
assistant_id=openai_assistant_id,
132-
instructions=instructions,
133+
assistant_config={
134+
"instructions": instructions,
135+
},
133136
)
134137
else:
135138
logger.warning(
@@ -154,9 +157,13 @@ def __init__(
154157
logger.warning(
155158
"overwrite_tools is True. Provided tools will be used and will modify the assistant in the API"
156159
)
157-
self._openai_assistant = self._openai_client.beta.assistants.update(
160+
self._openai_assistant = update_gpt_assistant(
161+
self._openai_client,
158162
assistant_id=openai_assistant_id,
159-
tools=openai_assistant_cfg.get("tools", []),
163+
assistant_config={
164+
"tools": specified_tools,
165+
"tool_resources": openai_assistant_cfg.get("tool_resources", None),
166+
},
160167
)
161168
else:
162169
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
@@ -198,6 +205,8 @@ def _invoke_assistant(
198205
assistant_thread = self._openai_threads[sender]
199206
# Process each unread message
200207
for message in pending_messages:
208+
if message["content"].strip() == "":
209+
continue
201210
self._openai_client.beta.threads.messages.create(
202211
thread_id=assistant_thread.id,
203212
content=message["content"],
@@ -426,22 +435,23 @@ def delete_assistant(self):
426435
logger.warning("Permanently deleting assistant...")
427436
self._openai_client.beta.assistants.delete(self.assistant_id)
428437

429-
def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
438+
def find_matching_assistant(self, candidate_assistants, instructions, tools):
430439
"""
431440
Find the matching assistant from a list of candidate assistants.
432-
Filter out candidates with the same name but different instructions, file IDs, and function names.
433-
TODO: implement accurate match based on assistant metadata fields.
441+
Filter out candidates with the same name but different instructions, and function names.
434442
"""
435443
matching_assistants = []
436444

437445
# Preprocess the required tools for faster comparison
438-
required_tool_types = set(tool.get("type") for tool in tools)
446+
required_tool_types = set(
447+
"file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
448+
)
449+
439450
required_function_names = set(
440451
tool.get("function", {}).get("name")
441452
for tool in tools
442-
if tool.get("type") not in ["code_interpreter", "retrieval"]
453+
if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
443454
)
444-
required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison
445455

446456
for assistant in candidate_assistants:
447457
# Check if instructions are similar
@@ -454,11 +464,12 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
454464
continue
455465

456466
# Preprocess the assistant's tools
457-
assistant_tool_types = set(tool.type for tool in assistant.tools)
467+
assistant_tool_types = set(
468+
"file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
469+
)
458470
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
459-
assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison
460471

461-
# Check if the tool types, function names, and file IDs match
472+
# Check if the tool types, function names match
462473
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
463474
logger.warning(
464475
"tools not match, skip assistant(%s): tools %s, functions %s",
@@ -467,9 +478,6 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
467478
assistant_function_names,
468479
)
469480
continue
470-
if required_file_ids != assistant_file_ids:
471-
logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
472-
continue
473481

474482
# Append assistant to matching list if all conditions are met
475483
matching_assistants.append(assistant)
@@ -496,7 +504,7 @@ def _process_assistant_config(self, llm_config, assistant_config):
496504

497505
# Move the assistant related configurations to assistant_config
498506
# It's important to keep forward compatibility
499-
assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
507+
assistant_config_items = ["assistant_id", "tools", "file_ids", "tool_resources", "check_every_ms"]
500508
for item in assistant_config_items:
501509
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
502510
openai_assistant_cfg[item] = openai_client_cfg[item]

autogen/oai/openai_utils.py

+103
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
import importlib.metadata
12
import json
23
import logging
34
import os
45
import re
56
import tempfile
7+
import time
68
from pathlib import Path
79
from typing import Any, Dict, List, Optional, Set, Union
810

911
from dotenv import find_dotenv, load_dotenv
1012
from openai import OpenAI
1113
from openai.types.beta.assistant import Assistant
14+
from packaging.version import parse
1215

1316
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
1417
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
@@ -675,3 +678,103 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
675678
if assistant.name == name:
676679
candidate_assistants.append(assistant)
677680
return candidate_assistants
681+
682+
683+
def detect_gpt_assistant_api_version() -> str:
    """Report which assistant API generation the installed `openai` package targets.

    Returns:
        "v1" when the installed `openai` distribution is older than 1.21,
        otherwise "v2".
    """
    installed_version = parse(importlib.metadata.version("openai"))
    # 1.21 is the cut-over release: anything at or above it is treated as v2.
    return "v2" if installed_version >= parse("1.21") else "v1"
690+
691+
692+
def create_gpt_vector_store(
    client: "OpenAI", name: str, fild_ids: List[str], poll_interval: float = 1.0
) -> Any:
    """Create an OpenAI vector store for a GPT assistant and attach files to it.

    Args:
        client: OpenAI client exposing `beta.vector_stores`.
        name: Name given to the new vector store.
        fild_ids: IDs of the files to add to the vector store. (The parameter
            name is kept as-is so existing keyword callers keep working.)
        poll_interval: Seconds to wait between status checks while the file
            batch is still reported as "in_progress".

    Returns:
        The vector store object returned by `client.beta.vector_stores.create`.

    Raises:
        ValueError: If the file batch ends in any status other than "completed".
    """
    vector_store = client.beta.vector_stores.create(name=name)
    # Upload the files as one batch and wait on its status.
    batch = client.beta.vector_stores.file_batches.create_and_poll(vector_store_id=vector_store.id, file_ids=fild_ids)

    # Keep polling for as long as the batch is unfinished. A single re-check
    # would fall through to the error below while the upload is still running.
    while batch.status == "in_progress":
        time.sleep(poll_interval)
        logging.debug(f"file batch status: {batch.file_counts}")
        batch = client.beta.vector_stores.file_batches.poll(vector_store_id=vector_store.id, batch_id=batch.id)

    if batch.status == "completed":
        return vector_store

    raise ValueError(f"Failed to upload files to vector store {vector_store.id}:{batch.status}")
708+
709+
710+
def create_gpt_assistant(
    client: OpenAI, name: str, instructions: str, model: str, assistant_config: Dict[str, Any]
) -> Assistant:
    """Create an OpenAI GPT assistant using the detected assistant API version.

    Args:
        client: OpenAI client used to create the assistant (and, under v2, any
            vector store needed for file search).
        name: Name of the assistant.
        instructions: System instructions for the assistant.
        model: Model name the assistant runs on.
        assistant_config: Assistant settings. The keys read here are `tools`,
            `file_ids` and `tool_resources`. This mapping and the tool dicts
            inside it are not modified.

    Returns:
        The assistant returned by `client.beta.assistants.create`.

    Raises:
        ValueError: Under v2, when both `tool_resources['file_search']` and
            `file_ids` are given. Under v1, when `tool_resources` is present
            or a `file_search` tool is requested.
    """
    assistant_create_kwargs = {}
    gpt_assistant_api_version = detect_gpt_assistant_api_version()
    # An explicit `None` is treated the same as a missing key.
    tools = assistant_config.get("tools") or []

    if gpt_assistant_api_version == "v2":
        # Work on copies: the loop below rewrites tool types and adds resource
        # entries, and the caller's config must not change as a side effect.
        tools = [dict(tool) for tool in tools]
        tool_resources = dict(assistant_config.get("tool_resources") or {})
        file_ids = assistant_config.get("file_ids")
        if tool_resources.get("file_search") is not None and file_ids is not None:
            raise ValueError(
                "Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
            )

        # Designed for backwards compatibility for the V1 API
        # Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
        for tool in tools:
            if tool["type"] == "retrieval":
                tool["type"] = "file_search"
                if file_ids is not None:
                    # create a vector store for the file search tool
                    vs = create_gpt_vector_store(client, f"{name}-vectorestore", file_ids)
                    tool_resources["file_search"] = {
                        "vector_store_ids": [vs.id],
                    }
            elif tool["type"] == "code_interpreter" and file_ids is not None:
                tool_resources["code_interpreter"] = {
                    "file_ids": file_ids,
                }

        assistant_create_kwargs["tools"] = tools
        if len(tool_resources) > 0:
            assistant_create_kwargs["tool_resources"] = tool_resources
    else:
        # not support forwards compatibility
        if "tool_resources" in assistant_config:
            raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
        if any(tool["type"] == "file_search" for tool in tools):
            raise ValueError(
                "`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
            )
        assistant_create_kwargs["tools"] = tools
        assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])

    logging.info(f"Creating assistant with config: {assistant_create_kwargs}")
    return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
759+
760+
761+
def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Dict[str, Any]) -> Assistant:
    """Update an existing OpenAI GPT assistant.

    Only settings whose value is not None are forwarded to the API: `tools`,
    `instructions`, and then `tool_resources` when the detected assistant API
    version is "v2" or `file_ids` otherwise.

    Args:
        client: OpenAI client used to issue the update.
        assistant_id: ID of the assistant to update.
        assistant_config: Settings to apply to the assistant.

    Returns:
        The assistant returned by `client.beta.assistants.update`.
    """
    # File attachments are configured under a different key per API version.
    versioned_key = "tool_resources" if detect_gpt_assistant_api_version() == "v2" else "file_ids"

    assistant_update_kwargs = {
        key: assistant_config[key]
        for key in ("tools", "instructions", versioned_key)
        if assistant_config.get(key) is not None
    }

    return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
__version__ = version["__version__"]
1515

1616
install_requires = [
17-
"openai>=1.3,<1.21",
17+
"openai>=1.3",
1818
"diskcache",
1919
"termcolor",
2020
"flaml",

test/agentchat/contrib/test_gpt_assistant.py

+15-29
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import autogen
1212
from autogen import OpenAIWrapper, UserProxyAgent
1313
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
14-
from autogen.oai.openai_utils import retrieve_assistants_by_name
14+
from autogen.oai.openai_utils import detect_gpt_assistant_api_version, retrieve_assistants_by_name
1515

1616
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
1717
from conftest import reason, skip_openai # noqa: E402
@@ -264,6 +264,7 @@ def test_get_assistant_files() -> None:
264264
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
265265
file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
266266
name = f"For test_get_assistant_files {uuid.uuid4()}"
267+
gpt_assistant_api_version = detect_gpt_assistant_api_version()
267268

268269
# keep it to test older version of assistant config
269270
assistant = GPTAssistantAgent(
@@ -277,10 +278,17 @@ def test_get_assistant_files() -> None:
277278
)
278279

279280
try:
280-
files = assistant.openai_client.beta.assistants.files.list(assistant_id=assistant.assistant_id)
281-
retrieved_file_ids = [fild.id for fild in files]
281+
if gpt_assistant_api_version == "v1":
282+
files = assistant.openai_client.beta.assistants.files.list(assistant_id=assistant.assistant_id)
283+
retrieved_file_ids = [fild.id for fild in files]
284+
elif gpt_assistant_api_version == "v2":
285+
oas_assistant = assistant.openai_client.beta.assistants.retrieve(assistant_id=assistant.assistant_id)
286+
vectorstore_ids = oas_assistant.tool_resources.file_search.vector_store_ids
287+
retrieved_file_ids = []
288+
for vectorstore_id in vectorstore_ids:
289+
files = assistant.openai_client.beta.vector_stores.files.list(vector_store_id=vectorstore_id)
290+
retrieved_file_ids.extend([fild.id for fild in files])
282291
expected_file_id = file.id
283-
284292
finally:
285293
assistant.delete_assistant()
286294
openai_client.files.delete(file.id)
@@ -401,7 +409,7 @@ def test_assistant_mismatch_retrieval() -> None:
401409
"tools": [
402410
{"type": "function", "function": function_1_schema},
403411
{"type": "function", "function": function_2_schema},
404-
{"type": "retrieval"},
412+
{"type": "file_search"},
405413
{"type": "code_interpreter"},
406414
],
407415
"file_ids": [file_1.id, file_2.id],
@@ -411,7 +419,6 @@ def test_assistant_mismatch_retrieval() -> None:
411419
name = f"For test_assistant_retrieval {uuid.uuid4()}"
412420

413421
assistant_first, assistant_instructions_mistaching = None, None
414-
assistant_file_ids_mismatch, assistant_tools_mistaching = None, None
415422
try:
416423
assistant_first = GPTAssistantAgent(
417424
name,
@@ -432,30 +439,11 @@ def test_assistant_mismatch_retrieval() -> None:
432439
)
433440
assert len(candidate_instructions_mistaching) == 2
434441

435-
# test mismatch fild ids
436-
file_ids_mismatch_llm_config = {
437-
"tools": [
438-
{"type": "code_interpreter"},
439-
{"type": "retrieval"},
440-
{"type": "function", "function": function_2_schema},
441-
{"type": "function", "function": function_1_schema},
442-
],
443-
"file_ids": [file_2.id],
444-
"config_list": openai_config_list,
445-
}
446-
assistant_file_ids_mismatch = GPTAssistantAgent(
447-
name,
448-
instructions="This is a test",
449-
llm_config=file_ids_mismatch_llm_config,
450-
)
451-
candidate_file_ids_mismatch = retrieve_assistants_by_name(assistant_file_ids_mismatch.openai_client, name)
452-
assert len(candidate_file_ids_mismatch) == 3
453-
454442
# test tools mismatch
455443
tools_mismatch_llm_config = {
456444
"tools": [
457445
{"type": "code_interpreter"},
458-
{"type": "retrieval"},
446+
{"type": "file_search"},
459447
{"type": "function", "function": function_3_schema},
460448
],
461449
"file_ids": [file_2.id, file_1.id],
@@ -467,15 +455,13 @@ def test_assistant_mismatch_retrieval() -> None:
467455
llm_config=tools_mismatch_llm_config,
468456
)
469457
candidate_tools_mismatch = retrieve_assistants_by_name(assistant_tools_mistaching.openai_client, name)
470-
assert len(candidate_tools_mismatch) == 4
458+
assert len(candidate_tools_mismatch) == 3
471459

472460
finally:
473461
if assistant_first:
474462
assistant_first.delete_assistant()
475463
if assistant_instructions_mistaching:
476464
assistant_instructions_mistaching.delete_assistant()
477-
if assistant_file_ids_mismatch:
478-
assistant_file_ids_mismatch.delete_assistant()
479465
if assistant_tools_mistaching:
480466
assistant_tools_mistaching.delete_assistant()
481467

0 commit comments

Comments
 (0)