11
11
12
12
from autogen .oai import completion
13
13
14
- from autogen .oai .openai_utils import get_key , OAI_PRICE1K
14
+ from autogen .oai .openai_utils import DEFAULT_AZURE_API_VERSION , get_key , OAI_PRICE1K
15
15
from autogen .token_count_utils import count_token
16
16
from autogen ._pydantic import model_dump
17
17
21
21
except ImportError :
22
22
ERROR : Optional [ImportError ] = ImportError ("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper." )
23
23
OpenAI = object
24
+ AzureOpenAI = object
24
25
else :
25
26
# raises exception if openai>=1 is installed and something is wrong with imports
26
- from openai import OpenAI , APIError , __version__ as OPENAIVERSION
27
+ from openai import OpenAI , AzureOpenAI , APIError , __version__ as OPENAIVERSION
27
28
from openai .resources import Completions
28
29
from openai .types .chat import ChatCompletion
29
30
from openai .types .chat .chat_completion import ChatCompletionMessage , Choice # type: ignore [attr-defined]
@@ -52,8 +53,18 @@ class OpenAIWrapper:
52
53
"""A wrapper class for openai client."""
53
54
54
55
cache_path_root : str = ".cache"
55
- extra_kwargs = {"cache_seed" , "filter_func" , "allow_format_str_template" , "context" , "api_version" , "tags" }
56
+ extra_kwargs = {
57
+ "cache_seed" ,
58
+ "filter_func" ,
59
+ "allow_format_str_template" ,
60
+ "context" ,
61
+ "api_version" ,
62
+ "api_type" ,
63
+ "tags" ,
64
+ }
56
65
openai_kwargs = set (inspect .getfullargspec (OpenAI .__init__ ).kwonlyargs )
66
+ aopenai_kwargs = set (inspect .getfullargspec (AzureOpenAI .__init__ ).kwonlyargs )
67
+ openai_kwargs = openai_kwargs | aopenai_kwargs
57
68
total_usage_summary : Optional [Dict [str , Any ]] = None
58
69
actual_usage_summary : Optional [Dict [str , Any ]] = None
59
70
@@ -105,46 +116,10 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
105
116
self ._clients = [self ._client (extra_kwargs , openai_config )]
106
117
self ._config_list = [extra_kwargs ]
107
118
108
- def _process_for_azure (
109
- self , config : Dict [str , Any ], extra_kwargs : Dict [str , Any ], segment : str = "default"
110
- ) -> None :
111
- # deal with api_version
112
- query_segment = f"{ segment } _query"
113
- headers_segment = f"{ segment } _headers"
114
- api_version = extra_kwargs .get ("api_version" )
115
- if api_version is not None and query_segment not in config :
116
- config [query_segment ] = {"api-version" : api_version }
117
- if segment == "default" :
118
- # remove the api_version from extra_kwargs
119
- extra_kwargs .pop ("api_version" )
120
- if segment == "extra" :
121
- return
122
- # deal with api_type
123
- api_type = extra_kwargs .get ("api_type" )
124
- if api_type is not None and api_type .startswith ("azure" ) and headers_segment not in config :
125
- api_key = config .get ("api_key" , os .environ .get ("AZURE_OPENAI_API_KEY" ))
126
- config [headers_segment ] = {"api-key" : api_key }
127
- # remove the api_type from extra_kwargs
128
- extra_kwargs .pop ("api_type" )
129
- # deal with model
130
- model = extra_kwargs .get ("model" )
131
- if model is None :
132
- return
133
- if "gpt-3.5" in model :
134
- # hack for azure gpt-3.5
135
- extra_kwargs ["model" ] = model = model .replace ("gpt-3.5" , "gpt-35" )
136
- base_url = config .get ("base_url" )
137
- if base_url is None :
138
- raise ValueError ("to use azure openai api, base_url must be specified." )
139
- suffix = f"/openai/deployments/{ model } "
140
- if not base_url .endswith (suffix ):
141
- config ["base_url" ] += suffix [1 :] if base_url .endswith ("/" ) else suffix
142
-
143
119
def _separate_openai_config (self , config : Dict [str , Any ]) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
144
120
"""Separate the config into openai_config and extra_kwargs."""
145
121
openai_config = {k : v for k , v in config .items () if k in self .openai_kwargs }
146
122
extra_kwargs = {k : v for k , v in config .items () if k not in self .openai_kwargs }
147
- self ._process_for_azure (openai_config , extra_kwargs )
148
123
return openai_config , extra_kwargs
149
124
150
125
def _separate_create_config (self , config : Dict [str , Any ]) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
@@ -156,10 +131,22 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any
156
131
def _client (self , config : Dict [str , Any ], openai_config : Dict [str , Any ]) -> OpenAI :
157
132
"""Create a client with the given config to override openai_config,
158
133
after removing extra kwargs.
134
+
135
+ For Azure models/deployment names there's a convenience modification of model removing dots in
136
+ the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
137
+ "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
138
+ from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
159
139
"""
160
140
openai_config = {** openai_config , ** {k : v for k , v in config .items () if k in self .openai_kwargs }}
161
- self ._process_for_azure (openai_config , config )
162
- client = OpenAI (** openai_config )
141
+ api_type = config .get ("api_type" )
142
+ if api_type is not None and api_type .startswith ("azure" ):
143
+ openai_config ["azure_deployment" ] = openai_config .get ("azure_deployment" , config .get ("model" ))
144
+ if openai_config ["azure_deployment" ] is not None :
145
+ openai_config ["azure_deployment" ] = openai_config ["azure_deployment" ].replace ("." , "" )
146
+ openai_config ["azure_endpoint" ] = openai_config .get ("azure_endpoint" , openai_config .pop ("base_url" , None ))
147
+ client = AzureOpenAI (** openai_config )
148
+ else :
149
+ client = OpenAI (** openai_config )
163
150
return client
164
151
165
152
@classmethod
@@ -242,8 +229,9 @@ def yes_or_no_filter(context, response):
242
229
full_config = {** config , ** self ._config_list [i ]}
243
230
# separate the config into create_config and extra_kwargs
244
231
create_config , extra_kwargs = self ._separate_create_config (full_config )
245
- # process for azure
246
- self ._process_for_azure (create_config , extra_kwargs , "extra" )
232
+ api_type = extra_kwargs .get ("api_type" )
233
+ if api_type and api_type .startswith ("azure" ) and "model" in create_config :
234
+ create_config ["model" ] = create_config ["model" ].replace ("." , "" )
247
235
# construct the create params
248
236
params = self ._construct_create_params (create_config , extra_kwargs )
249
237
# get the cache_seed, filter_func and context
0 commit comments