diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 78bdc2c21423..d9f9b4d04eec 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,6 +1,19 @@ # Release History ## 1.4.0b3 (Unreleased) +- First preview of new API for authenticating users with `DeviceCodeCredential` + and `InteractiveBrowserCredential` + - new method `authenticate` interactively authenticates a user, returns a + serializable `AuthenticationRecord` + - new constructor keyword arguments + - `authentication_record` enables initializing a credential with an + `AuthenticationRecord` from a prior authentication + - `disable_automatic_authentication=True` configures the credential to raise + `AuthenticationRequiredError` when interactive authentication is necessary + to acquire a token rather than immediately begin that authentication + - `enable_persistent_cache=True` configures these credentials to use a + persistent cache on supported platforms (in this release, Windows only). + By default they cache in memory only. ## 1.4.0b2 (2020-04-06) diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 7648a37ffa19..3fbc78c57bfa 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -4,7 +4,8 @@ # ------------------------------------ """Credentials for Azure SDK clients.""" -from ._exceptions import CredentialUnavailableError +from ._auth_record import AuthenticationRecord +from ._exceptions import AuthenticationRequiredError, CredentialUnavailableError from ._constants import KnownAuthorities from ._credentials import ( AuthorizationCodeCredential, @@ -22,6 +23,8 @@ __all__ = [ + "AuthenticationRecord", + "AuthenticationRequiredError", "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_auth_record.py b/sdk/identity/azure-identity/azure/identity/_auth_record.py new file mode 100644 index 000000000000..968e5e8a5588 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_auth_record.py @@ -0,0 +1,72 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import json + + +class AuthenticationRecord(object): + """A record which can initialize :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`""" + + def __init__(self, tenant_id, client_id, authority, home_account_id, username): + # type: (str, str, str, str, str) -> None + self._authority = authority + self._client_id = client_id + self._home_account_id = home_account_id + self._tenant_id = tenant_id + self._username = username + + @property + def authority(self): + # type: () -> str + return self._authority + + @property + def client_id(self): + # type: () -> str + return self._client_id + + @property + def home_account_id(self): + # type: () -> str + return self._home_account_id + + @property + def tenant_id(self): + # type: () -> str + return self._tenant_id + + @property + def username(self): + # type: () -> str + """The authenticated user's username""" + return self._username + + @classmethod + def deserialize(cls, json_string): + # type: (str) -> AuthenticationRecord + """Deserialize a record from JSON""" + + deserialized = json.loads(json_string) + + return cls( + authority=deserialized["authority"], + client_id=deserialized["client_id"], + home_account_id=deserialized["home_account_id"], + tenant_id=deserialized["tenant_id"], + username=deserialized["username"], + ) + + def serialize(self): + # type: () -> str + """Serialize the record to JSON""" + + record = { + "authority": self._authority, + "client_id": self._client_id, + "home_account_id": self._home_account_id, + "tenant_id": self._tenant_id, + "username": self._username, + } + + return json.dumps(record) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py index b01abef7d714..cec00aaec07d 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py @@ -3,16 +3,14 @@ # Licensed under the MIT License. # ------------------------------------ import socket -import time import uuid import webbrowser -from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from .. import CredentialUnavailableError from .._constants import AZURE_CLI_CLIENT_ID -from .._internal import AuthCodeRedirectServer, PublicClientCredential, wrap_exceptions +from .._internal import AuthCodeRedirectServer, InteractiveCredential, wrap_exceptions try: from typing import TYPE_CHECKING @@ -24,7 +22,7 @@ from typing import Any, List, Mapping -class InteractiveBrowserCredential(PublicClientCredential): +class InteractiveBrowserCredential(InteractiveCredential): """Opens a browser to interactively authenticate a user. :func:`~get_token` opens a browser to a login URL provided by Azure Active Directory and authenticates a user @@ -38,6 +36,11 @@ class InteractiveBrowserCredential(PublicClientCredential): authenticate work or school accounts. :keyword str client_id: Client ID of the Azure Active Directory application users will sign in to. If unspecified, the Azure CLI's ID will be used. + :keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate` + :keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise + :class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False. + :keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache shared by + other user credentials. **This is only supported on Windows.** Defaults to False. :keyword int timeout: seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes). """ @@ -49,42 +52,9 @@ def __init__(self, **kwargs): super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs) @wrap_exceptions - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (*str, **Any) -> AccessToken - """Request an access token for `scopes`. - - This will open a browser to a login page and listen on localhost for a request indicating authentication has - completed. - - .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code. - - :param str scopes: desired scopes for the access token. This method requires at least one scope. - :rtype: :class:`azure.core.credentials.AccessToken` - :raises ~azure.identity.CredentialUnavailableError: the credential is unable to start an HTTP server on - localhost, or is unable to open a browser - :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. Any error response from Azure Active Directory is available as the error's - ``response`` attribute. - """ - if not scopes: - raise ValueError("'get_token' requires at least one scope") - - return self._get_token_from_cache(scopes, **kwargs) or self._get_token_by_auth_code(scopes, **kwargs) + def _request_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> dict - def _get_token_from_cache(self, scopes, **kwargs): - """if the user has already signed in, we can redeem a refresh token for a new access token""" - app = self._get_app() - accounts = app.get_accounts() - if accounts: # => user has already authenticated - # MSAL asserts scopes is a list - scopes = list(scopes) # type: ignore - now = int(time.time()) - token = app.acquire_token_silent(scopes, account=accounts[0], **kwargs) - if token and "access_token" in token and "expires_in" in token: - return AccessToken(token["access_token"], now + int(token["expires_in"])) - return None - - def _get_token_by_auth_code(self, scopes, **kwargs): # start an HTTP server on localhost to receive the redirect for port in range(8400, 9000): try: @@ -118,13 +88,8 @@ def _get_token_by_auth_code(self, scopes, **kwargs): # redeem the authorization code for a token code = self._parse_response(request_state, response) - now = int(time.time()) - result = app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs) - - if "access_token" not in result: - raise ClientAuthenticationError(message="Authentication failed: {}".format(result.get("error_description"))) + return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs) - return AccessToken(result["access_token"], now + int(result["expires_in"])) @staticmethod def _parse_response(request_state, response): diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/user.py b/sdk/identity/azure-identity/azure/identity/_credentials/user.py index cf7e7d7ecaeb..7ff0ed0fbe74 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/user.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/user.py @@ -8,7 +8,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from .._internal import PublicClientCredential, wrap_exceptions +from .._internal import InteractiveCredential, PublicClientCredential, wrap_exceptions try: from typing import TYPE_CHECKING @@ -17,18 +17,16 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any, Callable, Optional + from typing import Any, Optional -class DeviceCodeCredential(PublicClientCredential): +class DeviceCodeCredential(InteractiveCredential): """Authenticates users through the device code flow. When :func:`get_token` is called, this credential acquires a verification URL and code from Azure Active Directory. A user must browse to the URL, enter the code, and authenticate with Azure Active Directory. If the user authenticates successfully, the credential receives an access token. - This credential doesn't cache tokens--each :func:`get_token` call begins a new authentication flow. - For more information about the device code flow, see Azure Active Directory documentation: https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code @@ -49,6 +47,11 @@ class DeviceCodeCredential(PublicClientCredential): - ``expires_on`` (datetime.datetime) the UTC time at which the code will expire If this argument isn't provided, the credential will print instructions to stdout. :paramtype prompt_callback: Callable[str, str, ~datetime.datetime] + :keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate` + :keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise + :class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False. + :keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache shared by + other user credentials. **This is only supported on Windows.** Defaults to False. """ def __init__(self, client_id, **kwargs): @@ -58,26 +61,11 @@ def __init__(self, client_id, **kwargs): super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) @wrap_exceptions - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (*str, **Any) -> AccessToken - """Request an access token for `scopes`. - - This credential won't cache the token. Each call begins a new authentication flow. - - .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code. - - :param str scopes: desired scopes for the access token. This method requires at least one scope. - :rtype: :class:`azure.core.credentials.AccessToken` - :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. Any error response from Azure Active Directory is available as the error's - ``response`` attribute. - """ - if not scopes: - raise ValueError("'get_token' requires at least one scope") + def _request_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> dict # MSAL requires scopes be a list scopes = list(scopes) # type: ignore - now = int(time.time()) app = self._get_app() flow = app.initiate_device_flow(scopes) @@ -95,7 +83,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if self._timeout is not None and self._timeout < flow["expires_in"]: # user specified an effective timeout we will observe - deadline = now + self._timeout + deadline = int(time.time()) + self._timeout result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline) else: # MSAL will stop polling when the device code expires @@ -108,8 +96,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) raise ClientAuthenticationError(message=message) - token = AccessToken(result["access_token"], now + int(result["expires_in"])) - return token + return result class UsernamePasswordCredential(PublicClientCredential): diff --git a/sdk/identity/azure-identity/azure/identity/_exceptions.py b/sdk/identity/azure-identity/azure/identity/_exceptions.py index 22802306976f..ef1199fdf3b9 100644 --- a/sdk/identity/azure-identity/azure/identity/_exceptions.py +++ b/sdk/identity/azure-identity/azure/identity/_exceptions.py @@ -2,8 +2,37 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from typing import TYPE_CHECKING + from azure.core.exceptions import ClientAuthenticationError +if TYPE_CHECKING: + from typing import Any, Optional, Sequence + class CredentialUnavailableError(ClientAuthenticationError): """The credential did not attempt to authenticate because required data or state is unavailable.""" + + +class AuthenticationRequiredError(CredentialUnavailableError): + """Interactive authentication is required to acquire a token.""" + + def __init__(self, scopes, message=None, error_details=None, **kwargs): + # type: (Sequence[str], Optional[str], Optional[str], **Any) -> None + self._scopes = scopes + self._error_details = error_details + if not message: + message = "Interactive authentication is required to get a token. Call 'authenticate' to begin." + super(AuthenticationRequiredError, self).__init__(message=message, **kwargs) + + @property + def scopes(self): + # type: () -> Sequence[str] + """Scopes requested during the failed authentication""" + return self._scopes + + @property + def error_details(self): + # type: () -> Optional[str] + """Additional authentication error details from Azure Active Directory""" + return self._error_details diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index da4f702d6842..b6d5ec1f2632 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -34,7 +34,7 @@ def get_default_authority(): from .aad_client_base import AadClientBase from .auth_code_redirect_handler import AuthCodeRedirectServer from .exception_wrapper import wrap_exceptions -from .msal_credentials import ConfidentialClientCredential, PublicClientCredential +from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse @@ -52,12 +52,16 @@ def _scopes_to_resource(*scopes): __all__ = [ + "_scopes_to_resource", "AadClient", "AadClientBase", "AuthCodeRedirectServer", "ConfidentialClientCredential", + "get_default_authority", + "InteractiveCredential", "MsalTransportAdapter", "MsalTransportResponse", + "normalize_authority", "PublicClientCredential", "wrap_exceptions", ] diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index b7fefdf30237..464e771e00cc 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -6,15 +6,24 @@ This entails monkeypatching MSAL's OAuth client with an adapter substituting an azure-core pipeline for Requests. """ import abc +import base64 +import json +import logging +import os +import sys import time import msal +from six.moves.urllib_parse import urlparse from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter +from .._constants import KnownAuthorities +from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError from .._internal import get_default_authority, normalize_authority +from .._auth_record import AuthenticationRecord try: ABC = abc.ABC @@ -27,23 +36,79 @@ TYPE_CHECKING = False if TYPE_CHECKING: - # pylint:disable=unused-import + # pylint:disable=ungrouped-imports,unused-import from typing import Any, Mapping, Optional, Type, Union +_LOGGER = logging.getLogger(__name__) + +_DEFAULT_AUTHENTICATE_SCOPES = { + "https://" + KnownAuthorities.AZURE_CHINA: ("https://management.core.chinacloudapi.cn//.default",), + "https://" + KnownAuthorities.AZURE_GERMANY: ("https://management.core.cloudapi.de//.default",), + "https://" + KnownAuthorities.AZURE_GOVERNMENT: ("https://management.core.usgovcloudapi.net//.default",), + "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD: ("https://management.core.windows.net//.default",), +} + + +def _decode_client_info(raw): + """Taken from msal.oauth2cli.oidc""" + + raw += "=" * (-len(raw) % 4) + raw = str(raw) # On Python 2.7, argument of urlsafe_b64decode must be str, not unicode. + return base64.urlsafe_b64decode(raw).decode("utf-8") + + +def _build_auth_record(response): + """Build an AuthenticationRecord from the result of an MSAL ClientApplication token request""" + + try: + client_info = json.loads(_decode_client_info(response["client_info"])) + id_token = response["id_token_claims"] + + return AuthenticationRecord( + authority=urlparse(id_token["iss"]).netloc, # "iss" is the URL of the issuing tenant + client_id=id_token["aud"], + home_account_id="{uid}.{utid}".format(**client_info), + tenant_id=id_token["tid"], # tenant which issued the token, not necessarily user's home tenant + username=id_token["preferred_username"], + ) + except (KeyError, ValueError): + # surprising: msal.ClientApplication always requests client_info and an id token, whose shapes shouldn't change + return None + + +def _load_persistent_cache(): + # type: () -> msal.TokenCache + + if sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ: + from msal_extensions.token_cache import WindowsTokenCache + + return WindowsTokenCache( + cache_location=os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", "msal.cache") + ) + + raise NotImplementedError("A persistent cache is not available on this platform.") + + class MsalCredential(ABC): """Base class for credentials wrapping MSAL applications""" def __init__(self, client_id, client_credential=None, **kwargs): # type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None - tenant_id = kwargs.pop("tenant_id", "organizations") authority = kwargs.pop("authority", None) - authority = normalize_authority(authority) if authority else get_default_authority() + self._authority = normalize_authority(authority) if authority else get_default_authority() + self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" - self._base_url = "/".join((authority, tenant_id.strip("/"))) self._client_credential = client_credential self._client_id = client_id + self._cache = kwargs.pop("_cache", None) # internal, for use in tests + if not self._cache: + if kwargs.pop("enable_persistent_cache", False): + self._cache = _load_persistent_cache() + else: + self._cache = msal.TokenCache() + self._adapter = kwargs.pop("msal_adapter", None) or MsalTransportAdapter(**kwargs) # postpone creating the wrapped application because its initializer uses the network @@ -66,7 +131,12 @@ def _create_app(self, cls): # MSAL application initializers use msal.authority to send AAD tenant discovery requests with self._adapter: # MSAL's "authority" is a URL e.g. https://login.microsoftonline.com/common - app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._base_url) + app = cls( + client_id=self._client_id, + client_credential=self._client_credential, + authority="{}/{}".format(self._authority, self._tenant_id), + token_cache=self._cache, + ) # monkeypatch the app to replace requests.Session with MsalTransportAdapter app.client.session.close() @@ -116,3 +186,115 @@ def _get_app(self): if not self._msal_app: self._msal_app = self._create_app(msal.PublicClientApplication) return self._msal_app + + +class InteractiveCredential(PublicClientCredential): + def __init__(self, **kwargs): + self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False) + self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] + if self._auth_record: + kwargs.pop("client_id", None) # authentication_record overrides client_id argument + tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id + super(InteractiveCredential, self).__init__( + client_id=self._auth_record.client_id, + authority=self._auth_record.authority, + tenant_id=tenant_id, + **kwargs + ) + else: + super(InteractiveCredential, self).__init__(**kwargs) + + def get_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + """Request an access token for `scopes`. + + .. note:: This method is called by Azure SDK clients. It isn't intended for use in application code. + + :param str scopes: desired scopes for the access token. This method requires at least one scope. + :rtype: :class:`azure.core.credentials.AccessToken` + :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks + required data, state, or platform support + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is + configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication. + """ + if not scopes: + raise ValueError("'get_token' requires at least one scope") + + allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication) + try: + return self._acquire_token_silent(*scopes, **kwargs) + except AuthenticationRequiredError: + if not allow_prompt: + raise + + # silent authentication failed -> authenticate interactively + now = int(time.time()) + + result = self._request_token(*scopes, **kwargs) + if "access_token" not in result: + message = "Authentication failed: {}".format(result.get("error_description") or result.get("error")) + raise ClientAuthenticationError(message=message) + + # this may be the first authentication, or the user may have authenticated a different identity + self._auth_record = _build_auth_record(result) + + return AccessToken(result["access_token"], now + int(result["expires_in"])) + + def authenticate(self, **kwargs): + # type: (**Any) -> AuthenticationRecord + """Interactively authenticate a user. + + :keyword Sequence[str] scopes: scopes to request during authentication, such as those provided by + :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token + for these scopes. + :rtype: ~azure.identity.AuthenticationRecord + :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` + attribute gives a reason. + """ + + scopes = kwargs.pop("scopes", None) + if not scopes: + if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES: + # the credential is configured to use a cloud whose ARM scope we can't determine + raise CredentialUnavailableError( + message="Authenticating in this environment requires a value for the 'scopes' keyword argument." + ) + + scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority] + + _ = self.get_token(*scopes, _allow_prompt=True, **kwargs) + return self.authentication_record # type: ignore + + @property + def authentication_record(self): + # type: () -> Optional[AuthenticationRecord] + """:class:`~azure.identity.AuthenticationRecord` for the most recent authentication""" + return self._auth_record + + @wrap_exceptions + def _acquire_token_silent(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + result = None + if self._auth_record: + app = self._get_app() + for account in app.get_accounts(username=self._auth_record.username): + if account.get("home_account_id") != self._auth_record.home_account_id: + continue + + now = int(time.time()) + result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs) + if result and "access_token" in result and "expires_in" in result: + return AccessToken(result["access_token"], now + int(result["expires_in"])) + + # if we get this far, result is either None or the content of an AAD error response + if result: + details = result.get("error_description") or result.get("error") + raise AuthenticationRequiredError(scopes, error_details=details) + raise AuthenticationRequiredError(scopes) + + @abc.abstractmethod + def _request_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> dict + """Request an access token via a non-silent MSAL token acquisition method, returning that method's result""" diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 7073a6c08e04..f31998fe2833 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -16,12 +16,29 @@ # build_* lifted from msal tests def build_id_token( - iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None, **claims + iss="issuer", + sub="subject", + aud="my_client_id", + username="username", + tenant_id="tenant id", + object_id="object id", + exp=None, + iat=None, + **claims ): # AAD issues "preferred_username", ADFS issues "upn" return "header.%s.signature" % base64.b64encode( json.dumps( dict( - {"iss": iss, "sub": sub, "aud": aud, "exp": exp or (time.time() + 100), "iat": iat or time.time()}, + { + "iss": iss, + "sub": sub, + "aud": aud, + "exp": exp or (time.time() + 100), + "iat": iat or time.time(), + "tid": tenant_id, + "oid": object_id, + "preferred_username": username, + }, **claims ) ).encode() @@ -83,7 +100,7 @@ def add_discrepancy(name, expected, actual): discrepancies.append("{}:\n\t expected: {}\n\t actual: {}".format(name, expected, actual)) if self.base_url and self.base_url != request.url.split("?")[0]: - add_discrepancy('base url', self.base_url, request.url) + add_discrepancy("base url", self.base_url, request.url) if self.url and self.url != request.url: add_discrepancy("url", self.url, request.url) diff --git a/sdk/identity/azure-identity/tests/test_auth_record.py b/sdk/identity/azure-identity/tests/test_auth_record.py new file mode 100644 index 000000000000..5daef9c5dec4 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_auth_record.py @@ -0,0 +1,29 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import json + +from azure.identity import AuthenticationRecord + + +def test_serialization(): + """serialize should accept arbitrary additional key/value pairs, which deserialize should ignore""" + + attrs = ("authority", "client_id","home_account_id", "tenant_id", "username") + nums = (n for n in range(len(attrs))) + record_values = {attr: next(nums) for attr in attrs} + + record = AuthenticationRecord(**record_values) + serialized = record.serialize() + + # AuthenticationRecord's fields should have been serialized + assert json.loads(serialized) == record_values + + deserialized = AuthenticationRecord.deserialize(serialized) + + # the deserialized record and the constructed record should have the same fields + assert sorted(vars(deserialized)) == sorted(vars(record)) + + # the constructed and deserialized records should have the same values + assert all(getattr(deserialized, attr) == record_values[attr] for attr in attrs) diff --git a/sdk/identity/azure-identity/tests/test_device_code_credential.py b/sdk/identity/azure-identity/tests/test_device_code_credential.py index 98a9fe4216ab..2b09a68f58ab 100644 --- a/sdk/identity/azure-identity/tests/test_device_code_credential.py +++ b/sdk/identity/azure-identity/tests/test_device_code_credential.py @@ -6,11 +6,19 @@ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import DeviceCodeCredential +from azure.identity import AuthenticationRequiredError, DeviceCodeCredential from azure.identity._internal.user_agent import USER_AGENT +from msal import TokenCache import pytest -from helpers import build_aad_response, get_discovery_response, mock_response, Request, validating_transport +from helpers import ( + build_aad_response, + build_id_token, + get_discovery_response, + mock_response, + Request, + validating_transport, +) try: from unittest.mock import Mock @@ -22,10 +30,76 @@ def test_no_scopes(): """The credential should raise when get_token is called with no scopes""" credential = DeviceCodeCredential("client_id") - with pytest.raises(ClientAuthenticationError): + with pytest.raises(ValueError): credential.get_token() +def test_authenticate(): + client_id = "client-id" + environment = "localhost" + issuer = "https://" + environment + tenant_id = "some-tenant" + authority = issuer + "/" + tenant_id + + access_token = "***" + scope = "scope" + + # mock AAD response with id token + object_id = "object-id" + home_tenant = "home-tenant-id" + username = "me@work.com" + id_token = build_id_token(aud=client_id, iss=issuer, object_id=object_id, tenant_id=home_tenant, username=username) + auth_response = build_aad_response( + uid=object_id, utid=home_tenant, access_token=access_token, refresh_token="**", id_token=id_token + ) + + transport = validating_transport( + requests=[Request(url_substring=issuer)] * 4, + responses=[get_discovery_response(authority)] * 2 # instance and tenant discovery + + [ + mock_response( # start device code flow + json_payload={ + "device_code": "_", + "user_code": "user-code", + "verification_uri": "verification-uri", + "expires_in": 42, + } + ), + mock_response(json_payload=dict(auth_response, scope=scope)), # poll for completion + ], + ) + + credential = DeviceCodeCredential( + client_id, + prompt_callback=Mock(), # prevent credential from printing to stdout + transport=transport, + authority=environment, + tenant_id=tenant_id, + _cache=TokenCache(), + ) + record = credential.authenticate(scopes=(scope,)) + + # credential should have a cached access token for the scope used in authenticate + token = credential.get_token(scope) + assert token.token == access_token + + assert record.authority == environment + assert record.home_account_id == object_id + "." + home_tenant + assert record.tenant_id == home_tenant + assert record.username == username + + +def test_disable_automatic_authentication(): + """When configured for strict silent auth, the credential should raise when silent auth fails""" + + empty_cache = TokenCache() # empty cache makes silent auth impossible + transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) + credential = DeviceCodeCredential("client-id", disable_automatic_authentication=True, transport=transport, _cache=empty_cache) + + with pytest.raises(AuthenticationRequiredError): + credential.get_token("scope") + + def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) @@ -47,7 +121,7 @@ def test_policies_configurable(): ) credential = DeviceCodeCredential( - client_id="client-id", prompt_callback=Mock(), policies=[policy], transport=transport + client_id="client-id", prompt_callback=Mock(), policies=[policy], transport=transport, _cache=TokenCache() ) credential.get_token("scope") @@ -72,7 +146,9 @@ def test_user_agent(): ], ) - credential = DeviceCodeCredential(client_id="client-id", prompt_callback=Mock(), transport=transport) + credential = DeviceCodeCredential( + client_id="client-id", prompt_callback=Mock(), transport=transport, _cache=TokenCache() + ) credential.get_token("scope") @@ -110,7 +186,7 @@ def test_device_code_credential(): callback = Mock() credential = DeviceCodeCredential( - client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False + client_id="_", prompt_callback=callback, transport=transport, instance_discovery=False, _cache=TokenCache() ) now = datetime.datetime.utcnow() @@ -142,7 +218,12 @@ def test_timeout(): ) credential = DeviceCodeCredential( - client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.01, instance_discovery=False + client_id="_", + prompt_callback=Mock(), + transport=transport, + timeout=0.01, + instance_discovery=False, + _cache=TokenCache(), ) with pytest.raises(ClientAuthenticationError) as ex: diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index 4dd245e8973d..5dd7ed4fbfee 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -10,14 +10,21 @@ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import InteractiveBrowserCredential +from azure.identity import AuthenticationRequiredError, InteractiveBrowserCredential from azure.identity._internal import AuthCodeRedirectServer from azure.identity._internal.user_agent import USER_AGENT - +from msal import TokenCache import pytest from six.moves import urllib, urllib_parse -from helpers import build_aad_response, get_discovery_response, mock_response, Request, validating_transport +from helpers import ( + build_aad_response, + build_id_token, + get_discovery_response, + mock_response, + Request, + validating_transport, +) try: from unittest.mock import Mock, patch @@ -25,13 +32,82 @@ from mock import Mock, patch # type: ignore +WEBBROWSER_OPEN = InteractiveBrowserCredential.__module__ + ".webbrowser.open" + + def test_no_scopes(): """The credential should raise when get_token is called with no scopes""" - with pytest.raises(ClientAuthenticationError): + with pytest.raises(ValueError): InteractiveBrowserCredential().get_token() +def test_authenticate(): + client_id = "client-id" + environment = "localhost" + issuer = "https://" + environment + tenant_id = "some-tenant" + authority = issuer + "/" + tenant_id + + access_token = "***" + scope = "scope" + + # mock AAD response with id token + object_id = "object-id" + home_tenant = "home-tenant-id" + username = "me@work.com" + id_token = build_id_token(aud=client_id, iss=issuer, object_id=object_id, tenant_id=home_tenant, username=username) + auth_response = build_aad_response( + uid=object_id, utid=home_tenant, access_token=access_token, refresh_token="**", id_token=id_token + ) + + transport = validating_transport( + requests=[Request(url_substring=issuer)] * 3, + responses=[get_discovery_response(authority)] * 2 + [mock_response(json_payload=auth_response)], + ) + + # mock local server fakes successful authentication by immediately returning a well-formed response + oauth_state = "state" + auth_code_response = {"code": "authorization-code", "state": [oauth_state]} + server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) + + with patch(InteractiveBrowserCredential.__module__ + ".uuid.uuid4", lambda: oauth_state): + with patch(WEBBROWSER_OPEN, lambda _: True): + credential = InteractiveBrowserCredential( + _cache=TokenCache(), + authority=environment, + client_id=client_id, + server_class=server_class, + tenant_id=tenant_id, + transport=transport, + ) + record = credential.authenticate(scopes=(scope,)) + + # credential should have a cached access token for the scope used in authenticate + with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential should authenticate silently"))): + token = credential.get_token(scope) + assert token.token == access_token + + assert record.authority == environment + assert record.home_account_id == object_id + "." + home_tenant + assert record.tenant_id == home_tenant + assert record.username == username + + +def test_disable_automatic_authentication(): + """When configured for strict silent auth, the credential should raise when silent auth fails""" + + empty_cache = TokenCache() # empty cache makes silent auth impossible + transport = Mock(send=Mock(side_effect=Exception("no request should be sent"))) + credential = InteractiveBrowserCredential( + disable_automatic_authentication=True, transport=transport, _cache=empty_cache + ) + + with patch(WEBBROWSER_OPEN, Mock(side_effect=Exception("credential shouldn't try interactive authentication"))): + with pytest.raises(AuthenticationRequiredError): + credential.get_token("scope") + + @patch("azure.identity._credentials.browser.webbrowser.open", lambda _: True) def test_policies_configurable(): policy = Mock(spec_set=SansIOHTTPPolicy, on_request=Mock()) @@ -46,7 +122,9 @@ def test_policies_configurable(): auth_code_response = {"code": "authorization-code", "state": [oauth_state]} server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) - credential = InteractiveBrowserCredential(policies=[policy], transport=transport, server_class=server_class) + credential = InteractiveBrowserCredential( + policies=[policy], transport=transport, server_class=server_class, _cache=TokenCache() + ) with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): credential.get_token("scope") @@ -66,7 +144,7 @@ def test_user_agent(): auth_code_response = {"code": "authorization-code", "state": [oauth_state]} server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response)) - credential = InteractiveBrowserCredential(transport=transport, server_class=server_class) + credential = InteractiveBrowserCredential(transport=transport, server_class=server_class, _cache=TokenCache()) with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state): credential.get_token("scope") @@ -101,7 +179,8 @@ def test_interactive_credential(mock_open): expires_in=expires_in, refresh_token=expected_refresh_token, uid="uid", - utid="utid", + utid=tenant_id, + id_token=build_id_token(aud=client_id, object_id="uid", tenant_id=tenant_id, iss=endpoint), token_type="Bearer", ) ), @@ -119,11 +198,11 @@ def test_interactive_credential(mock_open): authority=authority, tenant_id=tenant_id, client_id=client_id, - client_secret="secret", server_class=server_class, transport=transport, instance_discovery=False, validate_authority=False, + _cache=TokenCache(), ) # The credential's auth code request includes a uuid which must be included in the redirect. Patching to @@ -171,11 +250,11 @@ def test_interactive_credential_timeout(): credential = InteractiveBrowserCredential( client_id="guid", - client_secret="secret", server_class=server_class, timeout=timeout, transport=transport, instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request + _cache=TokenCache(), ) with pytest.raises(ClientAuthenticationError) as ex: @@ -216,7 +295,7 @@ def test_redirect_server(): def test_no_browser(): transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2) credential = InteractiveBrowserCredential( - client_id="client-id", client_secret="secret", server_class=Mock(), transport=transport + client_id="client-id", server_class=Mock(), transport=transport, _cache=TokenCache() ) with pytest.raises(ClientAuthenticationError, match=r".*browser.*"): credential.get_token("scope") diff --git a/sdk/identity/azure-identity/tests/test_msal_interactive_credential.py b/sdk/identity/azure-identity/tests/test_msal_interactive_credential.py new file mode 100644 index 000000000000..16b05efa643a --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_msal_interactive_credential.py @@ -0,0 +1,227 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.core.exceptions import ClientAuthenticationError +from azure.identity import ( + AuthenticationRequiredError, + AuthenticationRecord, + KnownAuthorities, + CredentialUnavailableError, +) +from azure.identity._internal.msal_credentials import InteractiveCredential +from msal import TokenCache +import pytest + +try: + from unittest.mock import Mock, patch +except ImportError: # python < 3.3 + from mock import Mock, patch # type: ignore + + +class MockCredential(InteractiveCredential): + """Test class to drive InteractiveCredential. + + Default instances have an empty in-memory cache, and raise rather than send an HTTP request. + """ + + def __init__( + self, client_id="...", request_token=None, cache=None, msal_app_factory=None, transport=None, **kwargs + ): + self._msal_app_factory = msal_app_factory + self._request_token_impl = request_token or Mock() + transport = transport or Mock(send=Mock(side_effect=Exception("credential shouldn't send a request"))) + super(MockCredential, self).__init__( + client_id=client_id, _cache=cache or TokenCache(), transport=transport, **kwargs + ) + + def _request_token(self, *scopes, **kwargs): + return self._request_token_impl(*scopes, **kwargs) + + def _get_app(self): + if self._msal_app_factory: + return self._create_app(self._msal_app_factory) + return super(MockCredential, self)._get_app() + + +def test_no_scopes(): + """The credential should raise when get_token is called with no scopes""" + + request_token = Mock(side_effect=Exception("credential shouldn't begin interactive authentication")) + with pytest.raises(ValueError): + MockCredential(request_token=request_token).get_token() + + +def test_authentication_record_argument(): + """The credential should initialize its msal.ClientApplication with values from a given record""" + + record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") + + def validate_app_parameters(authority, client_id, **_): + # the 'authority' argument to msal.ClientApplication should be a URL of the form https://authority/tenant + assert authority == "https://{}/{}".format(record.authority, record.tenant_id) + assert client_id == record.client_id + return Mock(get_accounts=Mock(return_value=[])) + + app_factory = Mock(wraps=validate_app_parameters) + credential = MockCredential( + authentication_record=record, disable_automatic_authentication=True, msal_app_factory=app_factory, + ) + with pytest.raises(AuthenticationRequiredError): + credential.get_token("scope") + + assert app_factory.call_count == 1, "credential didn't create an msal application" + + +def test_tenant_argument_overrides_record(): + """The 'tenant_ic' keyword argument should override a given record's value""" + + tenant_id = "some-guid" + authority = "localhost" + record = AuthenticationRecord(tenant_id, "client-id", authority, "object.tenant", "username") + + expected_tenant = tenant_id[::-1] + expected_authority = "https://{}/{}".format(authority, expected_tenant) + + def validate_authority(authority, **_): + assert authority == expected_authority + return Mock(get_accounts=Mock(return_value=[])) + + credential = MockCredential( + authentication_record=record, + tenant_id=expected_tenant, + disable_automatic_authentication=True, + msal_app_factory=validate_authority, + ) + with pytest.raises(AuthenticationRequiredError): + credential.get_token("scope") + + +def test_disable_automatic_authentication(): + """When silent auth fails the credential should raise, if it's configured not to authenticate automatically""" + + expected_details = "something went wrong" + record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") + msal_app = Mock( + acquire_token_silent_with_error=Mock(return_value={"error_description": expected_details}), + get_accounts=Mock(return_value=[{"home_account_id": record.home_account_id}]), + ) + + credential = MockCredential( + authentication_record=record, + disable_automatic_authentication=True, + msal_app_factory=lambda *_, **__: msal_app, + request_token=Mock(side_effect=Exception("credential shouldn't begin interactive authentication")), + ) + + scope = "scope" + with pytest.raises(AuthenticationRequiredError) as ex: + credential.get_token(scope) + + # the exception should carry the requested scopes and any error message from AAD + assert ex.value.scopes == (scope,) + assert ex.value.error_details == expected_details + + +def test_scopes_round_trip(): + """authenticate should accept the value of AuthenticationRequiredError.scopes""" + + scope = "scope" + + def validate_scopes(*scopes, **_): + assert scopes == (scope,) + return {"access_token": "**", "expires_in": 42} + + request_token = Mock(wraps=validate_scopes) + credential = MockCredential(disable_automatic_authentication=True, request_token=request_token) + with pytest.raises(AuthenticationRequiredError) as ex: + credential.get_token(scope) + + credential.authenticate(scopes=ex.value.scopes) + + assert request_token.call_count == 1, "validation method wasn't called" + + +@pytest.mark.parametrize( + "authority,expected_scope", + ( + (KnownAuthorities.AZURE_CHINA, "https://management.core.chinacloudapi.cn//.default"), + (KnownAuthorities.AZURE_GERMANY, "https://management.core.cloudapi.de//.default"), + (KnownAuthorities.AZURE_GOVERNMENT, "https://management.core.usgovcloudapi.net//.default"), + (KnownAuthorities.AZURE_PUBLIC_CLOUD, "https://management.core.windows.net//.default"), + ), +) +def test_authenticate_default_scopes(authority, expected_scope): + """when given no scopes, authenticate should default to the ARM scope appropriate for the configured authority""" + + def validate_scopes(*scopes): + assert scopes == (expected_scope,) + return {"access_token": "**", "expires_in": 42} + + request_token = Mock(wraps=validate_scopes) + MockCredential(authority=authority, request_token=request_token).authenticate() + assert request_token.call_count == 1 + + +def test_authenticate_unknown_cloud(): + """authenticate should raise when given no scopes in an unknown cloud""" + + with pytest.raises(CredentialUnavailableError): + MockCredential(authority="localhost").authenticate() + + +@pytest.mark.parametrize("option", (True, False)) +def test_authenticate_ignores_disable_automatic_authentication(option): + """authenticate should prompt for authentication regardless of the credential's configuration""" + + request_token = Mock(return_value={"access_token": "**", "expires_in": 42}) + MockCredential(request_token=request_token, disable_automatic_authentication=option).authenticate() + assert request_token.call_count == 1, "credential didn't begin interactive authentication" + + +def test_get_token_wraps_exceptions(): + """get_token shouldn't propagate exceptions from MSAL""" + + class CustomException(Exception): + pass + + expected_message = "something went wrong" + record = AuthenticationRecord("tenant-id", "client-id", "localhost", "object.tenant", "username") + msal_app = Mock( + acquire_token_silent_with_error=Mock(side_effect=CustomException(expected_message)), + get_accounts=Mock(return_value=[{"home_account_id": record.home_account_id}]), + ) + credential = MockCredential(msal_app_factory=lambda *_, **__: msal_app, authentication_record=record) + with pytest.raises(ClientAuthenticationError) as ex: + credential.get_token("scope") + + assert expected_message in ex.value.message + assert msal_app.acquire_token_silent_with_error.call_count == 1, "credential didn't attempt silent auth" + + +def test_enable_persistent_cache(): + """the credential should use the persistent cache only when given enable_persistent_cache=True""" + + class TestCredential(InteractiveCredential): + def _request_token(self, *_, **__): + pass + + expected_cache = Mock() + persistent_cache_loader = InteractiveCredential.__module__ + "._load_persistent_cache" + + # credential should default to an in memory cache + raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) + with patch(persistent_cache_loader, raise_when_called): + with patch(InteractiveCredential.__module__ + ".msal.TokenCache", lambda: expected_cache): + credential = TestCredential(client_id="...") + assert credential._cache is expected_cache + + # keyword argument opts in to persistent cache + with patch(persistent_cache_loader, lambda: expected_cache): + credential = TestCredential(client_id="...", enable_persistent_cache=True) + assert credential._cache is expected_cache + + # opting in on an unsupported platform raises an exception + with patch(InteractiveCredential.__module__ + ".sys.platform", "commodore64"): + with pytest.raises(NotImplementedError): + TestCredential(client_id="...", enable_persistent_cache=True) diff --git a/sdk/identity/azure-identity/tests/test_username_password_credential.py b/sdk/identity/azure-identity/tests/test_username_password_credential.py index 16eb0a246e0d..a3e154c89fae 100644 --- a/sdk/identity/azure-identity/tests/test_username_password_credential.py +++ b/sdk/identity/azure-identity/tests/test_username_password_credential.py @@ -3,17 +3,23 @@ # Licensed under the MIT License. # ------------------------------------ from azure.core.exceptions import ClientAuthenticationError -from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy +from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import UsernamePasswordCredential from azure.identity._internal.user_agent import USER_AGENT import pytest -from helpers import build_aad_response, get_discovery_response, mock_response, Request, validating_transport +from helpers import ( + build_aad_response, + get_discovery_response, + mock_response, + Request, + validating_transport, +) try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore def test_no_scopes(): @@ -80,3 +86,15 @@ def test_username_password_credential(): token = credential.get_token("scope") assert token.token == expected_token + + +def test_cache_persistence(): + """The credential should cache only in memory""" + + expected_cache = Mock() + raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) + with patch.multiple("msal_extensions.token_cache", WindowsTokenCache=raise_when_called): + with patch("msal.TokenCache", Mock(return_value=expected_cache)): + credential = UsernamePasswordCredential("...", "...", "...") + + assert credential._cache is expected_cache