Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions sdk/identity/azure-identity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,57 @@
# Key concepts

# Examples
Shortest path to an access token:
Authenticating as a service principal:
```py
# using a client secret
from azure.identity import ClientSecretCredential

credential = ClientSecretCredential(client_id, secret, tenant_id)

# all credentials implement get_token
token = credential.get_token(scopes=["https://vault.azure.net/.default"])

# using a certificate requires a thumbprint and PEM-encoded private key
from azure.identity import CertificateCredential
with open("private-key.pem") as f:
private_key = f.read()
credential = CertificateCredential(client_id, tenant_id, private_key, thumbprint)
```

Authenticating via environment variables:
```py
from azure.identity import EnvironmentCredential

# will authenticate with client secret or certificate,
# depending on which environment variables are set
# (see constants.py for expected variable names)
credential = EnvironmentCredential()
token = credential.get_token(scopes=["https://vault.azure.net/.default"])
```

Chaining together multiple credentials:
```py
from azure.identity import TokenCredentialChain

# default credentials are environment then managed identity
credential_chain = TokenCredentialChain.default()

scopes = ["https://vault.azure.net/.default"]
# the chain has a get_token method like all credentials
token = credential_chain.get_token(scopes) # try each credential in order, return the first token
```

Authenticating from a service client:
```py
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import BearerTokenCredentialPolicy

credential_chain = TokenCredentialChain.default()
scopes = ["https://vault.azure.net/.default"]

# BearerTokenCredentialPolicy gets tokens as necessary, adds appropriate auth headers to requests
policies = [BearerTokenCredentialPolicy(credential=credential_chain, scopes=scopes)]
pipeline = Pipeline(transport=some_transport, policies=policies)
```
# Troubleshooting

# Next steps
Expand Down
35 changes: 31 additions & 4 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,40 @@
# license information.
# --------------------------------------------------------------------------
from .exceptions import AuthenticationError
from .credentials import ClientSecretCredential, TokenCredentialChain
from .credentials import (
CertificateCredential,
ClientSecretCredential,
EnvironmentCredential,
ManagedIdentityCredential,
TokenCredentialChain,
)

__all__ = ["AuthenticationError", "ClientSecretCredential", "TokenCredentialChain"]
__all__ = [
"AuthenticationError",
"CertificateCredential",
"ClientSecretCredential",
"EnvironmentCredential",
"ManagedIdentityCredential",
"TokenCredentialChain",
]

try:
from .aio import AsyncClientSecretCredential, AsyncTokenCredentialChain
from .aio import (
AsyncCertificateCredential,
AsyncClientSecretCredential,
AsyncEnvironmentCredential,
AsyncManagedIdentityCredential,
AsyncTokenCredentialChain,
)

__all__.extend(["AsyncClientSecretCredential", "AsyncTokenCredentialChain"])
__all__.extend(
[
"AsyncCertificateCredential",
"AsyncClientSecretCredential",
"AsyncEnvironmentCredential",
"AsyncManagedIdentityCredential",
"AsyncTokenCredentialChain",
]
)
except SyntaxError:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from time import time

from azure.core import Configuration, HttpRequest
from azure.core.pipeline import Pipeline
from azure.core.pipeline import Pipeline, PipelineRequest
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, RetryPolicy
from azure.core.pipeline.transport import RequestsTransport
from azure.core.pipeline.transport import HttpTransport, RequestsTransport
from msal import TokenCache

from .exceptions import AuthenticationError
Expand All @@ -18,16 +18,21 @@
except ImportError:
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Any, Iterable, Mapping, Optional
from typing import Any, Dict, Iterable, Mapping, Optional
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.policies import HTTPPolicy


class _AuthnClientBase(object):
class AuthnClientBase(object):
"""Sans I/O authentication client methods"""

def __init__(self, auth_url, **kwargs):
# type: (str, Mapping[str, Any]) -> None
if not auth_url:
raise ValueError("auth_url")
super(_AuthnClientBase, self).__init__(**kwargs)
self._cache = TokenCache()
raise ValueError("auth_url should be the URL of an OAuth endpoint")
super(AuthnClientBase, self).__init__()
self._auth_url = auth_url
self._cache = TokenCache()

def get_cached_token(self, scopes):
# type: (Iterable[str]) -> Optional[str]
Expand All @@ -38,16 +43,8 @@ def get_cached_token(self, scopes):
return token["secret"]
return None

def _prepare_request(self, method="POST", form_data=None, params=None):
request = HttpRequest(method, self._auth_url)
if form_data:
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(form_data)
if params:
request.format_parameters(params)
return request

def _deserialize_and_cache_token(self, response, scopes):
# type: (PipelineResponse, Iterable[str]) -> str
try:
if "deserialized_data" in response.context:
payload = response.context["deserialized_data"]
Expand All @@ -61,18 +58,31 @@ def _deserialize_and_cache_token(self, response, scopes):
except Exception as ex:
raise AuthenticationError("Authentication failed: {}".format(str(ex)))

def _prepare_request(self, method="POST", form_data=None, params=None):
# type: (Optional[str], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> HttpRequest
request = HttpRequest(method, self._auth_url)
if form_data:
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request.set_formdata_body(form_data)
if params:
request.format_parameters(params)
return request


class AuthnClient(AuthnClientBase):
"""Synchronous authentication client"""

class AuthnClient(_AuthnClientBase):
def __init__(self, auth_url, config=None, policies=None, transport=None, **kwargs):
# type: (str, Optional[Configuration], Optional[Iterable[HTTPPolicy]], Optional[HttpTransport], Mapping[str, Any]) -> None
config = config or self.create_config(**kwargs)
# TODO: ContentDecodePolicy doesn't accept kwargs
policies = policies or [ContentDecodePolicy(), config.logging_policy, config.retry_policy]
if not transport:
transport = RequestsTransport(configuration=config)
self._pipeline = Pipeline(transport=transport, policies=policies)
super(AuthnClient, self).__init__(auth_url, **kwargs)

def request_token(self, scopes, method="POST", form_data=None, params=None):
# type: (Iterable[str], Optional[str], Optional[Mapping[str, str]], Optional[Dict[str, str]]) -> str
request = self._prepare_request(method, form_data, params)
response = self._pipeline.run(request, stream=False)
token = self._deserialize_and_cache_token(response, scopes)
Expand All @@ -83,5 +93,5 @@ def create_config(**kwargs):
# type: (Mapping[str, Any]) -> Configuration
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = RetryPolicy(retry_on_status_codes=[404, 429] + [x for x in range(500, 600)], **kwargs)
config.retry_policy = RetryPolicy(retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs)
return config
52 changes: 52 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from msal.oauth2cli import JwtSigner

from .constants import OAUTH_ENDPOINT

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Mapping


class ClientSecretCredentialBase(object):
def __init__(self, client_id, secret, tenant_id, **kwargs):
# type: (str, str, str, Mapping[str, Any]) -> None
if not client_id:
raise ValueError("client_id should be the id of an Azure Active Directory application")
if not secret:
raise ValueError("secret should be an Azure Active Directory application's client secret")
if not tenant_id:
raise ValueError("tenant_id should be an Azure Active Directory tenant's id (also called its 'directory id')")
self._form_data = {"client_id": client_id, "client_secret": secret, "grant_type": "client_credentials"}
super(ClientSecretCredentialBase, self).__init__()


class CertificateCredentialBase(object):
def __init__(self, client_id, tenant_id, certificate_path, **kwargs):
# type: (str, str, str, Mapping[str, Any]) -> None
if not certificate_path:
# TODO: support PFX
raise ValueError("certificate_path must be the path to a PEM-encoded private key file")

super(CertificateCredentialBase, self).__init__()
auth_url = OAUTH_ENDPOINT.format(tenant_id)

with open(certificate_path) as pem:
private_key = pem.read()
signer = JwtSigner(private_key, "RS256")
client_assertion = signer.sign_assertion(audience=auth_url, issuer=client_id)
self._form_data = {
"client_assertion": client_assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": client_id,
"grant_type": "client_credentials",
}
16 changes: 14 additions & 2 deletions sdk/identity/azure-identity/azure/identity/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from .credentials import AsyncClientSecretCredential, AsyncTokenCredentialChain
from .credentials import (
AsyncCertificateCredential,
AsyncClientSecretCredential,
AsyncEnvironmentCredential,
AsyncManagedIdentityCredential,
AsyncTokenCredentialChain,
)

__all__ = ["AsyncClientSecretCredential", "AsyncTokenCredentialChain"]
__all__ = [
"AsyncCertificateCredential",
"AsyncClientSecretCredential",
"AsyncEnvironmentCredential",
"AsyncManagedIdentityCredential",
"AsyncTokenCredentialChain",
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Any, Iterable, Mapping, Optional
from typing import Any, Dict, Iterable, Mapping, Optional

from azure.core import Configuration
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncRetryPolicy, ContentDecodePolicy, HTTPPolicy, NetworkTraceLoggingPolicy
from azure.core.pipeline.transport import AsyncHttpTransport
from azure.core.pipeline.transport.requests_asyncio import AsyncioRequestsTransport

from ..authn_client import _AuthnClientBase
from .._authn_client import AuthnClientBase


class AsyncAuthnClient(_AuthnClientBase):
class AsyncAuthnClient(AuthnClientBase):
"""Async authentication client"""

def __init__(
self,
auth_url: str,
Expand All @@ -24,7 +26,6 @@ def __init__(
**kwargs: Mapping[str, Any]
) -> None:
config = config or self.create_config(**kwargs)
# TODO: ContentDecodePolicy doesn't accept kwargs
policies = policies or [ContentDecodePolicy(), config.logging_policy, config.retry_policy]
if not transport:
transport = AsyncioRequestsTransport(configuration=config)
Expand All @@ -36,7 +37,7 @@ async def request_token(
scopes: Iterable[str],
method: Optional[str] = "POST",
form_data: Optional[Mapping[str, str]] = None,
params: Optional[Mapping[str, str]] = None,
params: Optional[Dict[str, str]] = None,
) -> str:
request = self._prepare_request(method, form_data, params)
response = await self._pipeline.run(request, stream=False)
Expand All @@ -48,6 +49,6 @@ def create_config(**kwargs: Mapping[str, Any]) -> Configuration:
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = AsyncRetryPolicy(
retry_on_status_codes=[404, 429] + [x for x in range(500, 600)], **kwargs
retry_on_status_codes=[404, 429] + list(range(500, 600)), **kwargs
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

404 [](start = 35, length = 3)

Why are you retrying on 404?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right? According to the IMDS docs the endpoint returns 404 when it's updating, and clients should retry.

Copy link
Member

@johanste johanste Jun 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a big fan of having to create a list with >100 items. There are some things we can do to improve. Just off the top of my head, something like:

class StatusCodeRange(object):

    def __init__(self, min_value, max_value):
        self.min_value = min_value
        self.max_value = max_value

    def __eq__(self, other):
        return self.min_value <= other and other <= self.max_value


scr = StatusCodeRange(500, 600)

assert 503 in [404, 429, scr]
assert 499 not in [1, 2, scr]

It does bastardize the eq method, though, so another version would be to change the simple in check in the retry policy to understand ranges as well...

class SuperRange(object):

    def __init__(self, *values):
        self.values = values

    def __contains__(self, value):
        for value_in_range in self.values:
            try:
                if value_in_range[0] <= value and value <= value_in_range[1]:
                    return True
            except TypeError:
                if value_in_range == value:
                    return True
        return False


sur = SuperRange(404, 429, (500, 600))


assert 503 in sur
assert 499 not in sur

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right? According to the IMDS docs the endpoint returns 404 when it's updating, and clients should retry.

Per the linked documentation (great find, btw :)), we should make sure our retry policy follows the suggested timeouts by default. We are trying to make sure that our libraries "do the right thing" (tm).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have done. I put #5628 on the backlog to track retry code range support.

)
return config
Loading