Skip to content

Commit

Permalink
Support AAD pass-through auth
Browse files Browse the repository at this point in the history
  • Loading branch information
dargilco committed Sep 19, 2024
1 parent 05915eb commit d2d97ac
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 33 deletions.
12 changes: 9 additions & 3 deletions sdk/ai/azure-ai-client/azure/ai/client/operations/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class ConnectionsOperations(ConnectionsOperationsGenerated):

def get_credential(
self, *, connection_name: str = None, **kwargs
self, *, connection_name: str | None = None, **kwargs
) -> Tuple[Union[str, AzureKeyCredential, TokenCredential], str]:

if connection_name == "":
Expand Down Expand Up @@ -52,8 +52,14 @@ def get_credential(
return credential, endpoint
else:
raise ValueError("Unknown connection category `{response.properties.category}`.")
# elif response.properties.auth_type == AuthType.AAD:
# credentials = self._config.credential
elif response.properties.auth_type == AuthType.AAD:
if response.properties.category == ConnectionCategory.AZURE_OPEN_AI:
credential = self._config.credential
return credential, endpoint
elif response.properties.category == ConnectionCategory.SERVERLESS:
raise ValueError("Serverless API does not support AAD authentication.")
else:
raise ValueError("Unknown connection category `{response.properties.category}`.")
# elif response.properties.auth_type == AuthType.SAS:
# credentials =
else:
Expand Down
70 changes: 40 additions & 30 deletions sdk/ai/azure-ai-client/samples/sample_get_aoai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
from azure.ai.client import AzureAIClient
from openai import AzureOpenAI
from azure.identity import DefaultAzureCredential
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

ai_client = AzureAIClient(
credential=DefaultAzureCredential(),
Expand All @@ -23,38 +23,48 @@
logging_enable=True,
)

key, endpoint = ai_client.connections.get_credential(connection_name=os.environ["AI_STUDIO_CONNECTION_1"])
use_key_auth = False

client = AzureOpenAI(
api_key=key,
azure_endpoint=endpoint,
api_version="2024-08-01-preview", # See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview#api-specs
)
if use_key_auth:
key, endpoint = ai_client.connections.get_credential(connection_name=os.environ["AI_STUDIO_CONNECTION_1"])

completion = client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "user",
"content": "How many feet are in a mile?",
},
],
)
client = AzureOpenAI(
api_key=key,
azure_endpoint=endpoint,
api_version="2024-08-01-preview", # See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview#api-specs
)

print(f"\n\n===============> {completion.choices[0].message.content}\n\n")
completion = client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "user",
"content": "How many feet are in a mile?",
},
],
)

exit()
print(f"\n\n===============> {completion.choices[0].message.content}\n\n")

# Get an AOAI client for an AAD-auth connection:
aoai_client = client.get_azure_openai_client(connection_name=os.environ["AI_STUDIO_CONNECTION_2"])
# Use AAD passthrough auth:
else:

completion = aoai_client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "user",
"content": "What's the distance from earth to the moon in miles?",
},
],
)
print(f"\n\n===============> {completion.choices[0].message.content}\n\n")
credential, endpoint = ai_client.connections.get_credential(connection_name=os.environ["AI_STUDIO_CONNECTION_1"])

client = AzureOpenAI(
# See https://learn.microsoft.com/python/api/azure-identity/azure.identity?view=azure-python#azure-identity-get-bearer-token-provider
azure_ad_token_provider=get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default"),
azure_endpoint=endpoint,
api_version="2024-08-01-preview", # See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview#api-specs
)

completion = client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "user",
"content": "What's the distance from earth to the moon in miles?",
},
],
)
print(f"\n\n===============> {completion.choices[0].message.content}\n\n")

0 comments on commit d2d97ac

Please sign in to comment.