From b270a2e46793ae51923d3babcf4d7f0c9ea61ed9 Mon Sep 17 00:00:00 2001
From: Ian
Date: Thu, 15 Feb 2024 13:29:08 +0800
Subject: [PATCH] support azure assistant api (#1616)

* support azure assistant api

* try to add azure testing

* improve testing

* fix testing

* fix code

---------

Co-authored-by: Chi Wang
---
 .../agentchat/contrib/gpt_assistant_agent.py | 11 +++-
 test/agentchat/contrib/test_gpt_assistant.py | 59 +++++++++++++------
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index dc2967e103ef..e5916781cd67 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -53,9 +53,16 @@ def __init__(
                 - Other kwargs: Except verbose, others are passed directly to ConversableAgent.
         """
         # Use AutoGen OpenAIWrapper to create a client
-        oai_wrapper = OpenAIWrapper(**llm_config)
+        openai_client_cfg = None
+        model_name = "gpt-4-1106-preview"
+        if llm_config and llm_config.get("config_list") is not None and len(llm_config["config_list"]) > 0:
+            openai_client_cfg = llm_config["config_list"][0].copy()
+            model_name = openai_client_cfg.pop("model", "gpt-4-1106-preview")
+
+        oai_wrapper = OpenAIWrapper(**openai_client_cfg)
         if len(oai_wrapper._clients) > 1:
             logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
+
         self._openai_client = oai_wrapper._clients[0]._oai_client
         openai_assistant_id = llm_config.get("assistant_id", None)
         if openai_assistant_id is None:
@@ -79,7 +86,7 @@ def __init__(
                     name=name,
                     instructions=instructions,
                     tools=llm_config.get("tools", []),
-                    model=llm_config.get("model", "gpt-4-1106-preview"),
+                    model=model_name,
                     file_ids=llm_config.get("file_ids", []),
                 )
             else:
diff --git a/test/agentchat/contrib/test_gpt_assistant.py b/test/agentchat/contrib/test_gpt_assistant.py
index 92e12558afc7..ef7d74732091 100644
--- a/test/agentchat/contrib/test_gpt_assistant.py
+++ b/test/agentchat/contrib/test_gpt_assistant.py
@@ -23,9 +23,14 @@
 skip = False or skip_openai
 
 if not skip:
-    config_list = autogen.config_list_from_json(
+    openai_config_list = autogen.config_list_from_json(
         OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"api_type": ["openai"]}
     )
+    aoai_config_list = autogen.config_list_from_json(
+        OAI_CONFIG_LIST,
+        file_location=KEY_LOC,
+        filter_dict={"api_type": ["azure"], "api_version": ["2024-02-15-preview"]},
+    )
 
 
 @pytest.mark.skipif(
@@ -33,7 +38,8 @@
     reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
 )
 def test_config_list() -> None:
-    assert len(config_list) > 0
+    assert len(openai_config_list) > 0
+    assert len(aoai_config_list) > 0
 
 
 @pytest.mark.skipif(
@@ -41,6 +47,11 @@ def test_config_list() -> None:
     reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
 )
 def test_gpt_assistant_chat() -> None:
+    for gpt_config in [openai_config_list, aoai_config_list]:
+        _test_gpt_assistant_chat(gpt_config)
+
+
+def _test_gpt_assistant_chat(gpt_config) -> None:
     ossinsight_api_schema = {
         "name": "ossinsight_data_api",
         "parameters": {
@@ -64,7 +75,7 @@ def ask_ossinsight(question: str) -> str:
     name = f"For test_gpt_assistant_chat {uuid.uuid4()}"
     analyst = GPTAssistantAgent(
         name=name,
-        llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": config_list},
+        llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": gpt_config},
         instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
     )
     try:
@@ -90,7 +101,7 @@ def ask_ossinsight(question: str) -> str:
     # check the question asked
     ask_ossinsight_mock.assert_called_once()
     question_asked = ask_ossinsight_mock.call_args[0][0].lower()
-    for word in "microsoft autogen stars github".split(" "):
+    for word in "microsoft autogen star github".split(" "):
         assert word in question_asked
 
     # check the answer
@@ -108,6 +119,11 @@ def ask_ossinsight(question: str) -> str:
     reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
 )
 def test_get_assistant_instructions() -> None:
+    for gpt_config in [openai_config_list, aoai_config_list]:
+        _test_get_assistant_instructions(gpt_config)
+
+
+def _test_get_assistant_instructions(gpt_config) -> None:
     """
     Test function to create a new GPTAssistantAgent, set its instructions, retrieve the instructions,
     and assert that the retrieved instructions match the set instructions.
@@ -117,7 +133,7 @@ def test_get_assistant_instructions() -> None:
         name,
         instructions="This is a test",
         llm_config={
-            "config_list": config_list,
+            "config_list": gpt_config,
         },
     )
 
@@ -132,6 +148,11 @@ def test_get_assistant_instructions() -> None:
     reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
 )
 def test_gpt_assistant_instructions_overwrite() -> None:
+    for gpt_config in [openai_config_list, aoai_config_list]:
+        _test_gpt_assistant_instructions_overwrite(gpt_config)
+
+
+def _test_gpt_assistant_instructions_overwrite(gpt_config) -> None:
     """
     Test that the instructions of a GPTAssistantAgent can be overwritten or not depending on the value of the
     `overwrite_instructions` parameter when creating a new assistant with the same ID.
@@ -151,7 +172,7 @@ def test_gpt_assistant_instructions_overwrite() -> None:
         name,
         instructions=instructions1,
         llm_config={
-            "config_list": config_list,
+            "config_list": gpt_config,
         },
     )
 
@@ -161,7 +182,7 @@ def test_gpt_assistant_instructions_overwrite() -> None:
         name,
         instructions=instructions2,
         llm_config={
-            "config_list": config_list,
+            "config_list": gpt_config,
             "assistant_id": assistant_id,
         },
         overwrite_instructions=True,
@@ -191,7 +212,7 @@ def test_gpt_assistant_existing_no_instructions() -> None:
         name,
         instructions=instructions,
         llm_config={
-            "config_list": config_list,
+            "config_list": openai_config_list,
         },
     )
 
@@ -202,7 +223,7 @@ def test_gpt_assistant_existing_no_instructions() -> None:
     assistant = GPTAssistantAgent(
         name,
         llm_config={
-            "config_list": config_list,
+            "config_list": openai_config_list,
             "assistant_id": assistant_id,
         },
     )
@@ -225,7 +246,7 @@ def test_get_assistant_files() -> None:
     and assert that the retrieved instructions match the set instructions.
     """
     current_file_path = os.path.abspath(__file__)
-    openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
+    openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
     file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
     name = f"For test_get_assistant_files {uuid.uuid4()}"
 
@@ -233,7 +254,7 @@ def test_get_assistant_files() -> None:
         name,
         instructions="This is a test",
         llm_config={
-            "config_list": config_list,
+            "config_list": openai_config_list,
             "tools": [{"type": "retrieval"}],
             "file_ids": [file.id],
         },
@@ -274,7 +295,7 @@ def test_assistant_retrieval() -> None:
         "description": "This is a test function 2",
     }
 
-    openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
+    openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
     current_file_path = os.path.abspath(__file__)
     file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
     file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
@@ -289,7 +310,7 @@
             {"type": "code_interpreter"},
         ],
         "file_ids": [file_1.id, file_2.id],
-        "config_list": config_list,
+        "config_list": openai_config_list,
     }
 
     name = f"For test_assistant_retrieval {uuid.uuid4()}"
@@ -350,7 +371,7 @@ def test_assistant_mismatch_retrieval() -> None:
         "description": "This is a test function 3",
     }
 
-    openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
+    openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
     current_file_path = os.path.abspath(__file__)
     file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
     file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
@@ -364,7 +385,7 @@
             {"type": "code_interpreter"},
         ],
         "file_ids": [file_1.id, file_2.id],
-        "config_list": config_list,
+        "config_list": openai_config_list,
     }
 
     name = f"For test_assistant_retrieval {uuid.uuid4()}"
@@ -400,7 +421,7 @@
             {"type": "function", "function": function_1_schema},
         ],
         "file_ids": [file_2.id],
-        "config_list": config_list,
+        "config_list": openai_config_list,
     }
     assistant_file_ids_mismatch = GPTAssistantAgent(
         name,
@@ -418,7 +439,7 @@
             {"type": "function", "function": function_3_schema},
         ],
         "file_ids": [file_2.id, file_1.id],
-        "config_list": config_list,
+        "config_list": openai_config_list,
     }
     assistant_tools_mistaching = GPTAssistantAgent(
         name,
@@ -536,7 +557,7 @@ def test_gpt_assistant_tools_overwrite() -> None:
     assistant_org = GPTAssistantAgent(
         name,
         llm_config={
-            "config_list": config_list,
+            "config_list": openai_config_list,
             "tools": original_tools,
         },
     )
@@ -548,7 +569,7 @@ def test_gpt_assistant_tools_overwrite() -> None:
     assistant = GPTAssistantAgent(
         name,
         llm_config={
-            "config_list": config_list,
+            "config_list": openai_config_list,
             "assistant_id": assistant_id,
             "tools": new_tools,
         },