Skip to content

Commit

Permalink
support azure assistant api (#1616)
Browse files Browse the repository at this point in the history
* support azure assistant api

* try to add azure testing

* improve testing

* fix testing

* fix code

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
IANTHEREAL and sonichi authored Feb 15, 2024
1 parent cff9ca9 commit b270a2e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 21 deletions.
11 changes: 9 additions & 2 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,16 @@ def __init__(
- Other kwargs: Except verbose, others are passed directly to ConversableAgent.
"""
# Use AutoGen OpenAIWrapper to create a client
oai_wrapper = OpenAIWrapper(**llm_config)
openai_client_cfg = None
model_name = "gpt-4-1106-preview"
if llm_config and llm_config.get("config_list") is not None and len(llm_config["config_list"]) > 0:
openai_client_cfg = llm_config["config_list"][0].copy()
model_name = openai_client_cfg.pop("model", "gpt-4-1106-preview")

oai_wrapper = OpenAIWrapper(**openai_client_cfg)
if len(oai_wrapper._clients) > 1:
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")

self._openai_client = oai_wrapper._clients[0]._oai_client
openai_assistant_id = llm_config.get("assistant_id", None)
if openai_assistant_id is None:
Expand All @@ -79,7 +86,7 @@ def __init__(
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
model=model_name,
file_ids=llm_config.get("file_ids", []),
)
else:
Expand Down
59 changes: 40 additions & 19 deletions test/agentchat/contrib/test_gpt_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,35 @@
skip = False or skip_openai

if not skip:
config_list = autogen.config_list_from_json(
openai_config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"api_type": ["openai"]}
)
aoai_config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"api_type": ["azure"], "api_version": ["2024-02-15-preview"]},
)


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip,
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_config_list() -> None:
assert len(config_list) > 0
assert len(openai_config_list) > 0
assert len(aoai_config_list) > 0


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip,
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_gpt_assistant_chat() -> None:
for gpt_config in [openai_config_list, aoai_config_list]:
_test_gpt_assistant_chat(gpt_config)


def _test_gpt_assistant_chat(gpt_config) -> None:
ossinsight_api_schema = {
"name": "ossinsight_data_api",
"parameters": {
Expand All @@ -64,7 +75,7 @@ def ask_ossinsight(question: str) -> str:
name = f"For test_gpt_assistant_chat {uuid.uuid4()}"
analyst = GPTAssistantAgent(
name=name,
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": config_list},
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": gpt_config},
instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
)
try:
Expand All @@ -90,7 +101,7 @@ def ask_ossinsight(question: str) -> str:
# check the question asked
ask_ossinsight_mock.assert_called_once()
question_asked = ask_ossinsight_mock.call_args[0][0].lower()
for word in "microsoft autogen stars github".split(" "):
for word in "microsoft autogen star github".split(" "):
assert word in question_asked

# check the answer
Expand All @@ -108,6 +119,11 @@ def ask_ossinsight(question: str) -> str:
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_get_assistant_instructions() -> None:
for gpt_config in [openai_config_list, aoai_config_list]:
_test_get_assistant_instructions(gpt_config)


def _test_get_assistant_instructions(gpt_config) -> None:
"""
Test function to create a new GPTAssistantAgent, set its instructions, retrieve the instructions,
and assert that the retrieved instructions match the set instructions.
Expand All @@ -117,7 +133,7 @@ def test_get_assistant_instructions() -> None:
name,
instructions="This is a test",
llm_config={
"config_list": config_list,
"config_list": gpt_config,
},
)

Expand All @@ -132,6 +148,11 @@ def test_get_assistant_instructions() -> None:
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_gpt_assistant_instructions_overwrite() -> None:
for gpt_config in [openai_config_list, aoai_config_list]:
_test_gpt_assistant_instructions_overwrite(gpt_config)


def _test_gpt_assistant_instructions_overwrite(gpt_config) -> None:
"""
Test that the instructions of a GPTAssistantAgent can be overwritten or not depending on the value of the
`overwrite_instructions` parameter when creating a new assistant with the same ID.
Expand All @@ -151,7 +172,7 @@ def test_gpt_assistant_instructions_overwrite() -> None:
name,
instructions=instructions1,
llm_config={
"config_list": config_list,
"config_list": gpt_config,
},
)

Expand All @@ -161,7 +182,7 @@ def test_gpt_assistant_instructions_overwrite() -> None:
name,
instructions=instructions2,
llm_config={
"config_list": config_list,
"config_list": gpt_config,
"assistant_id": assistant_id,
},
overwrite_instructions=True,
Expand Down Expand Up @@ -191,7 +212,7 @@ def test_gpt_assistant_existing_no_instructions() -> None:
name,
instructions=instructions,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
},
)

Expand All @@ -202,7 +223,7 @@ def test_gpt_assistant_existing_no_instructions() -> None:
assistant = GPTAssistantAgent(
name,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"assistant_id": assistant_id,
},
)
Expand All @@ -225,15 +246,15 @@ def test_get_assistant_files() -> None:
and assert that the retrieved instructions match the set instructions.
"""
current_file_path = os.path.abspath(__file__)
openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
name = f"For test_get_assistant_files {uuid.uuid4()}"

assistant = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"tools": [{"type": "retrieval"}],
"file_ids": [file.id],
},
Expand Down Expand Up @@ -274,7 +295,7 @@ def test_assistant_retrieval() -> None:
"description": "This is a test function 2",
}

openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
current_file_path = os.path.abspath(__file__)

file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
Expand All @@ -289,7 +310,7 @@ def test_assistant_retrieval() -> None:
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
"config_list": config_list,
"config_list": openai_config_list,
}

name = f"For test_assistant_retrieval {uuid.uuid4()}"
Expand Down Expand Up @@ -350,7 +371,7 @@ def test_assistant_mismatch_retrieval() -> None:
"description": "This is a test function 3",
}

openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
current_file_path = os.path.abspath(__file__)
file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
Expand All @@ -364,7 +385,7 @@ def test_assistant_mismatch_retrieval() -> None:
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
"config_list": config_list,
"config_list": openai_config_list,
}

name = f"For test_assistant_retrieval {uuid.uuid4()}"
Expand Down Expand Up @@ -400,7 +421,7 @@ def test_assistant_mismatch_retrieval() -> None:
{"type": "function", "function": function_1_schema},
],
"file_ids": [file_2.id],
"config_list": config_list,
"config_list": openai_config_list,
}
assistant_file_ids_mismatch = GPTAssistantAgent(
name,
Expand All @@ -418,7 +439,7 @@ def test_assistant_mismatch_retrieval() -> None:
{"type": "function", "function": function_3_schema},
],
"file_ids": [file_2.id, file_1.id],
"config_list": config_list,
"config_list": openai_config_list,
}
assistant_tools_mistaching = GPTAssistantAgent(
name,
Expand Down Expand Up @@ -536,7 +557,7 @@ def test_gpt_assistant_tools_overwrite() -> None:
assistant_org = GPTAssistantAgent(
name,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"tools": original_tools,
},
)
Expand All @@ -548,7 +569,7 @@ def test_gpt_assistant_tools_overwrite() -> None:
assistant = GPTAssistantAgent(
name,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"assistant_id": assistant_id,
"tools": new_tools,
},
Expand Down

0 comments on commit b270a2e

Please sign in to comment.