Skip to content
7 changes: 3 additions & 4 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Iterable # pylint:disable=unused-import
from typing_extensions import Protocol


class SupportsGetToken(Protocol):
class TokenCredential(Protocol):
"""Protocol for classes able to provide OAuth tokens.

:param str scopes: Lets you specify the type of access needed.
"""
# pylint:disable=too-few-public-methods
def get_token(self, scopes):
# type: (Iterable[str]) -> str
def get_token(self, *scopes):
# type: (*str) -> str
pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Iterable, Mapping
from azure.core.credentials import SupportsGetToken
from azure.core.credentials import TokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse


Expand All @@ -22,12 +22,12 @@ class _BearerTokenCredentialPolicyBase(object):
"""Base class for a Bearer Token Credential Policy.

:param credential: The credential.
:type credential: ~azure.core.SupportsGetToken
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, scopes, **kwargs): # pylint:disable=unused-argument
# type: (SupportsGetToken, Iterable[str], Mapping[str, Any]) -> None
# type: (TokenCredential, Iterable[str], Mapping[str, Any]) -> None
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
Expand All @@ -47,7 +47,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.SupportsGetToken
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, AsyncHT
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.SupportsGetToken
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

Expand Down
15 changes: 14 additions & 1 deletion sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@
TokenCredentialChain,
)


class DefaultAzureCredential(TokenCredentialChain):
"""default credential is environment followed by MSI/IMDS"""

def __init__(self, **kwargs):
super(DefaultAzureCredential, self).__init__(
EnvironmentCredential(**kwargs), ManagedIdentityCredential(**kwargs)
)


__all__ = [
"AuthenticationError",
"CertificateCredential",
"ClientSecretCredential",
"DefaultAzureCredential",
"EnvironmentCredential",
"ManagedIdentityCredential",
"TokenCredentialChain",
Expand All @@ -25,6 +36,7 @@
from .aio import (
AsyncCertificateCredential,
AsyncClientSecretCredential,
AsyncDefaultAzureCredential,
AsyncEnvironmentCredential,
AsyncManagedIdentityCredential,
AsyncTokenCredentialChain,
Expand All @@ -34,10 +46,11 @@
[
"AsyncCertificateCredential",
"AsyncClientSecretCredential",
"AsyncDefaultAzureCredential",
"AsyncEnvironmentCredential",
"AsyncManagedIdentityCredential",
"AsyncTokenCredentialChain",
]
)
except SyntaxError:
except (ImportError, SyntaxError):
pass
22 changes: 15 additions & 7 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# -------------------------------------------------------------------------
from time import time

from azure.core import Configuration, HttpRequest
Expand Down Expand Up @@ -51,16 +51,24 @@ def _deserialize_and_cache_token(self, response, scopes):
else:
payload = response.http_response.text()
token = payload["access_token"]

# these values are strings in IMDS responses but msal.TokenCache requires they be integers
# https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/55
if payload.get("expires_in"):
payload["expires_in"] = int(payload["expires_in"])
if payload.get("ext_expires_in"):
payload["ext_expires_in"] = int(payload["ext_expires_in"])

self._cache.add({"response": payload, "scope": scopes})
return token
except KeyError:
raise AuthenticationError("Unexpected authentication response: {}".format(payload))
except Exception as ex:
raise AuthenticationError("Authentication failed: {}".format(str(ex)))

def _prepare_request(self, method="POST", form_data=None, params=None):
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
request = HttpRequest(method, self._auth_url)
def _prepare_request(self, method="POST", headers=None, form_data=None, params=None):
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
request = HttpRequest(method, self._auth_url, headers=headers)
if form_data:
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(form_data)
Expand All @@ -81,9 +89,9 @@ def __init__(self, auth_url, config=None, policies=None, transport=None, **kwarg
self._pipeline = Pipeline(transport=transport, policies=policies)
super(AuthnClient, self).__init__(auth_url, **kwargs)

def request_token(self, scopes, method="POST", form_data=None, params=None):
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> str
request = self._prepare_request(method, form_data, params)
def request_token(self, scopes, method="POST", headers=None, form_data=None, params=None):
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> str
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
response = self._pipeline.run(request, stream=False)
token = self._deserialize_and_cache_token(response, scopes)
return token
Expand Down
4 changes: 2 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# --------------------------------------------------------------------------
from msal.oauth2cli import JwtSigner

from .constants import OAUTH_ENDPOINT
from .constants import Endpoints

try:
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(self, client_id, tenant_id, certificate_path, **kwargs):
raise ValueError("certificate_path must be the path to a PEM-encoded private key file")

super(CertificateCredentialBase, self).__init__()
auth_url = OAUTH_ENDPOINT.format(tenant_id)
auth_url = Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id)

with open(certificate_path) as pem:
private_key = pem.read()
Expand Down
105 changes: 105 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# ------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# ------------------------------------------------------------------------
import os

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Optional

from azure.core import Configuration
from azure.core.pipeline.policies import ContentDecodePolicy, HeadersPolicy, NetworkTraceLoggingPolicy, RetryPolicy

from ._authn_client import AuthnClient
from .constants import Endpoints, MSI_ENDPOINT, MSI_SECRET
from .exceptions import AuthenticationError


class ImdsCredential:
"""Authenticates with a managed identity via the IMDS endpoint"""

def __init__(self, config=None, **kwargs):
# type: (Optional[Configuration], Dict[str, Any]) -> None
config = config or self.create_config(**kwargs)
policies = [config.header_policy, ContentDecodePolicy(), config.logging_policy, config.retry_policy]
self._client = AuthnClient(Endpoints.IMDS, config, policies, **kwargs)

@staticmethod
def create_config(**kwargs):
# type: (Dict[str, str]) -> Configuration
timeout = kwargs.pop("connection_timeout", 2)
config = Configuration(connection_timeout=timeout, **kwargs)
config.header_policy = HeadersPolicy(base_headers={"Metadata": "true"}, **kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
retries = kwargs.pop("retry_total", 5)
config.retry_policy = RetryPolicy(
retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
)
return config

def get_token(self, *scopes):
# type: (*str) -> str
if len(scopes) != 1:
raise ValueError("this credential supports one scope per request")
token = self._client.get_cached_token(scopes)
if not token:
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[:-len("/.default")]
token = self._client.request_token(
scopes, method="GET", params={"api-version": "2018-02-01", "resource": resource}
)
return token


class MsiCredential:
"""Authenticates via the MSI endpoint"""

def __init__(self, config=None, **kwargs):
# type: (Optional[Configuration], Dict[str, Any]) -> None
config = config or self.create_config(**kwargs)
policies = [ContentDecodePolicy(), config.retry_policy, config.logging_policy]
endpoint = os.environ.get(MSI_ENDPOINT)
if not endpoint:
raise ValueError("expected environment variable {} has no value".format(MSI_ENDPOINT))
self._client = AuthnClient(endpoint, config, policies, **kwargs)

@staticmethod
def create_config(**kwargs):
# type: (Dict[str, str]) -> Configuration
timeout = kwargs.pop("connection_timeout", 2)
config = Configuration(connection_timeout=timeout, **kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
retries = kwargs.pop("retry_total", 5)
config.retry_policy = RetryPolicy(
retry_total=retries, retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
)
return config

def get_token(self, *scopes):
# type: (*str) -> str
if len(scopes) != 1:
raise ValueError("this credential supports only one scope per request")
token = self._client.get_cached_token(scopes)
if not token:
secret = os.environ.get(MSI_SECRET)
if not secret:
raise AuthenticationError("{} environment variable has no value".format(MSI_SECRET))
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[:-len("/.default")]
# TODO: support user-assigned client id
token = self._client.request_token(
scopes,
method="GET",
headers={"secret": secret},
params={"api-version": "2017-09-01", "resource": resource},
)
return token
13 changes: 11 additions & 2 deletions sdk/identity/azure-identity/azure/identity/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -------------------------------------------------------------------------
# ------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# ------------------------------------------------------------------------
from .credentials import (
AsyncCertificateCredential,
AsyncClientSecretCredential,
Expand All @@ -11,9 +11,18 @@
AsyncTokenCredentialChain,
)


class AsyncDefaultAzureCredential(AsyncTokenCredentialChain):
"""default credential is environment followed by MSI/IMDS"""

def __init__(self, **kwargs):
super().__init__(AsyncEnvironmentCredential(**kwargs), AsyncManagedIdentityCredential(**kwargs))


__all__ = [
"AsyncCertificateCredential",
"AsyncClientSecretCredential",
"AsyncDefaultAzureCredential",
"AsyncEnvironmentCredential",
"AsyncManagedIdentityCredential",
"AsyncTokenCredentialChain",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ async def request_token(
self,
scopes: Iterable[str],
method: Optional[str] = "POST",
headers: Optional[Mapping[str, str]] = None,
form_data: Optional[Mapping[str, str]] = None,
params: Optional[Dict[str, str]] = None,
) -> str:
request = self._prepare_request(method, form_data, params)
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
response = await self._pipeline.run(request, stream=False)
token = self._deserialize_and_cache_token(response, scopes)
return token
Expand Down
Loading