Skip to content

Commit 34a0e99

Browse files
afourneyvictordibia
authored andcommitted
Updated the azure client to support AAD auth. (#2879)
1 parent 906c4c7 commit 34a0e99

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

autogen/logger/sqlite_logger.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,16 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> N
236236

237237
args = to_dict(
238238
init_args,
239-
exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint"),
239+
exclude=(
240+
"self",
241+
"__class__",
242+
"api_key",
243+
"organization",
244+
"base_url",
245+
"azure_endpoint",
246+
"azure_ad_token",
247+
"azure_ad_token_provider",
248+
),
240249
no_recursive=(Agent,),
241250
)
242251

@@ -301,7 +310,17 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM
301310
return
302311

303312
args = to_dict(
304-
init_args, exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint")
313+
init_args,
314+
exclude=(
315+
"self",
316+
"__class__",
317+
"api_key",
318+
"organization",
319+
"base_url",
320+
"azure_endpoint",
321+
"azure_ad_token",
322+
"azure_ad_token_provider",
323+
),
305324
)
306325

307326
query = """
@@ -323,7 +342,17 @@ def log_new_client(
323342
return
324343

325344
args = to_dict(
326-
init_args, exclude=("self", "__class__", "api_key", "organization", "base_url", "azure_endpoint")
345+
init_args,
346+
exclude=(
347+
"self",
348+
"__class__",
349+
"api_key",
350+
"organization",
351+
"base_url",
352+
"azure_endpoint",
353+
"azure_ad_token",
354+
"azure_ad_token_provider",
355+
),
327356
)
328357

329358
query = """

autogen/oai/client.py

+8
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,14 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st
407407
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
408408
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
409409

410+
# Create a default Azure token provider if requested
411+
if openai_config.get("azure_ad_token_provider") == "DEFAULT":
412+
import azure.identity
413+
414+
openai_config["azure_ad_token_provider"] = azure.identity.get_bearer_token_provider(
415+
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
416+
)
417+
410418
def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None:
411419
"""Create a client with the given config to override openai_config,
412420
after removing extra kwargs.

autogen/oai/openai_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from openai.types.beta.assistant import Assistant
1414
from packaging.version import parse
1515

16-
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
16+
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
1717
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
1818
OAI_PRICE1K = {
1919
# https://openai.com/api/pricing/

0 commit comments

Comments
 (0)