Skip to content

Commit 5079366

Browse files
maxim-saplinsonichi
authored and committed
Switched to AzureOpenAI for api_type=="azure" (#1232)
* Switched to AzureOpenAI for api_type=="azure" * Setting AzureOpenAI to empty object if no `openai` * extra_ and openai_ kwargs * test_client, support for Azure and "gpt-35-turbo-instruct" * instruct/azure model in test_client_stream * generalize aoai support (#1) * generalize aoai support * Null check, fixing tests * cleanup test --------- Co-authored-by: Maxim Saplin <[email protected]> * Returning back model names for instruct * process model in create * None check --------- Co-authored-by: Chi Wang <[email protected]>
1 parent cecff5a commit 5079366

File tree

3 files changed

+58
-60
lines changed

3 files changed

+58
-60
lines changed

autogen/oai/client.py

+31-43
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from autogen.oai import completion
1313

14-
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
14+
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
1515
from autogen.token_count_utils import count_token
1616
from autogen._pydantic import model_dump
1717

@@ -21,9 +21,10 @@
2121
except ImportError:
2222
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
2323
OpenAI = object
24+
AzureOpenAI = object
2425
else:
2526
# raises exception if openai>=1 is installed and something is wrong with imports
26-
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
27+
from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION
2728
from openai.resources import Completions
2829
from openai.types.chat import ChatCompletion
2930
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
@@ -52,8 +53,18 @@ class OpenAIWrapper:
5253
"""A wrapper class for openai client."""
5354

5455
cache_path_root: str = ".cache"
55-
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
56+
extra_kwargs = {
57+
"cache_seed",
58+
"filter_func",
59+
"allow_format_str_template",
60+
"context",
61+
"api_version",
62+
"api_type",
63+
"tags",
64+
}
5665
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
66+
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
67+
openai_kwargs = openai_kwargs | aopenai_kwargs
5768
total_usage_summary: Optional[Dict[str, Any]] = None
5869
actual_usage_summary: Optional[Dict[str, Any]] = None
5970

@@ -105,46 +116,10 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
105116
self._clients = [self._client(extra_kwargs, openai_config)]
106117
self._config_list = [extra_kwargs]
107118

108-
def _process_for_azure(
109-
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
110-
) -> None:
111-
# deal with api_version
112-
query_segment = f"{segment}_query"
113-
headers_segment = f"{segment}_headers"
114-
api_version = extra_kwargs.get("api_version")
115-
if api_version is not None and query_segment not in config:
116-
config[query_segment] = {"api-version": api_version}
117-
if segment == "default":
118-
# remove the api_version from extra_kwargs
119-
extra_kwargs.pop("api_version")
120-
if segment == "extra":
121-
return
122-
# deal with api_type
123-
api_type = extra_kwargs.get("api_type")
124-
if api_type is not None and api_type.startswith("azure") and headers_segment not in config:
125-
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
126-
config[headers_segment] = {"api-key": api_key}
127-
# remove the api_type from extra_kwargs
128-
extra_kwargs.pop("api_type")
129-
# deal with model
130-
model = extra_kwargs.get("model")
131-
if model is None:
132-
return
133-
if "gpt-3.5" in model:
134-
# hack for azure gpt-3.5
135-
extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35")
136-
base_url = config.get("base_url")
137-
if base_url is None:
138-
raise ValueError("to use azure openai api, base_url must be specified.")
139-
suffix = f"/openai/deployments/{model}"
140-
if not base_url.endswith(suffix):
141-
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
142-
143119
def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
144120
"""Separate the config into openai_config and extra_kwargs."""
145121
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
146122
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
147-
self._process_for_azure(openai_config, extra_kwargs)
148123
return openai_config, extra_kwargs
149124

150125
def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -156,10 +131,22 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any
156131
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
157132
"""Create a client with the given config to override openai_config,
158133
after removing extra kwargs.
134+
135+
For Azure models/deployment names there's a convenience modification of model removing dots in
136+
its value (Azure deployment names can't have dots). I.e. if you have Azure deployment name
137+
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
138+
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
159139
"""
160140
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
161-
self._process_for_azure(openai_config, config)
162-
client = OpenAI(**openai_config)
141+
api_type = config.get("api_type")
142+
if api_type is not None and api_type.startswith("azure"):
143+
openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
144+
if openai_config["azure_deployment"] is not None:
145+
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
146+
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
147+
client = AzureOpenAI(**openai_config)
148+
else:
149+
client = OpenAI(**openai_config)
163150
return client
164151

165152
@classmethod
@@ -242,8 +229,9 @@ def yes_or_no_filter(context, response):
242229
full_config = {**config, **self._config_list[i]}
243230
# separate the config into create_config and extra_kwargs
244231
create_config, extra_kwargs = self._separate_create_config(full_config)
245-
# process for azure
246-
self._process_for_azure(create_config, extra_kwargs, "extra")
232+
api_type = extra_kwargs.get("api_type")
233+
if api_type and api_type.startswith("azure") and "model" in create_config:
234+
create_config["model"] = create_config["model"].replace(".", "")
247235
# construct the create params
248236
params = self._construct_create_params(create_config, extra_kwargs)
249237
# get the cache_seed, filter_func and context

test/oai/test_client.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,15 @@ def test_aoai_chat_completion():
3131
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
3232
)
3333
client = OpenAIWrapper(config_list=config_list)
34-
# for config in config_list:
35-
# print(config)
36-
# client = OpenAIWrapper(**config)
37-
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
34+
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
35+
print(response)
36+
print(client.extract_text_or_completion_object(response))
37+
38+
# test dialect
39+
config = config_list[0]
40+
config["azure_deployment"] = config["model"]
41+
config["azure_endpoint"] = config.pop("base_url")
42+
client = OpenAIWrapper(**config)
3843
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
3944
print(response)
4045
print(client.extract_text_or_completion_object(response))
@@ -93,21 +98,23 @@ def test_chat_completion():
9398
def test_completion():
9499
config_list = config_list_openai_aoai(KEY_LOC)
95100
client = OpenAIWrapper(config_list=config_list)
96-
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
101+
model = "gpt-3.5-turbo-instruct"
102+
response = client.create(prompt="1+1=", model=model)
97103
print(response)
98104
print(client.extract_text_or_completion_object(response))
99105

100106

101107
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
102108
@pytest.mark.parametrize(
103-
"cache_seed, model",
109+
"cache_seed",
104110
[
105-
(None, "gpt-3.5-turbo-instruct"),
106-
(42, "gpt-3.5-turbo-instruct"),
111+
None,
112+
42,
107113
],
108114
)
109-
def test_cost(cache_seed, model):
115+
def test_cost(cache_seed):
110116
config_list = config_list_openai_aoai(KEY_LOC)
117+
model = "gpt-3.5-turbo-instruct"
111118
client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
112119
response = client.create(prompt="1+3=", model=model)
113120
print(response.cost)
@@ -117,7 +124,8 @@ def test_cost(cache_seed, model):
117124
def test_usage_summary():
118125
config_list = config_list_openai_aoai(KEY_LOC)
119126
client = OpenAIWrapper(config_list=config_list)
120-
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=None)
127+
model = "gpt-3.5-turbo-instruct"
128+
response = client.create(prompt="1+3=", model=model, cache_seed=None)
121129

122130
# usage should be recorded
123131
assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
@@ -138,15 +146,15 @@ def test_usage_summary():
138146
assert client.total_usage_summary is None, "total_usage_summary should be None"
139147

140148
# actual usage and all usage should be different
141-
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=42)
149+
response = client.create(prompt="1+3=", model=model, cache_seed=42)
142150
assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
143151
assert client.actual_usage_summary is None, "No actual cost should be recorded"
144152

145153

146154
if __name__ == "__main__":
147-
test_aoai_chat_completion()
148-
test_oai_tool_calling_extraction()
149-
test_chat_completion()
155+
# test_aoai_chat_completion()
156+
# test_oai_tool_calling_extraction()
157+
# test_chat_completion()
150158
test_completion()
151-
# test_cost()
152-
test_usage_summary()
159+
# # test_cost()
160+
# test_usage_summary()

test/oai/test_client_stream.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,9 @@ def test_chat_tools_stream() -> None:
286286
def test_completion_stream() -> None:
287287
config_list = config_list_openai_aoai(KEY_LOC)
288288
client = OpenAIWrapper(config_list=config_list)
289-
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
289+
# Azure can't have a dot in the model/deployment name
290+
model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct"
291+
response = client.create(prompt="1+1=", model=model, stream=True)
290292
print(response)
291293
print(client.extract_text_or_completion_object(response))
292294

0 commit comments

Comments
 (0)