diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 1de29521c222..d3a88265d246 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -14,6 +14,10 @@ future version. ([#10816](https://github.com/Azure/azure-sdk-for-python/issues/10816)) +### Breaking changes +- Removed `authentication_record` keyword argument from the async + `SharedTokenCacheCredential`, i.e. `azure.identity.aio.SharedTokenCacheCredential` + ## 1.4.0 (2020-08-10) ### Added - `DefaultAzureCredential` uses the value of environment variable diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 741dcc30bf03..8ad6aaceb4b9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -2,10 +2,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import time + +from msal.application import PublicClientApplication + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError + from .. import CredentialUnavailableError from .._constants import AZURE_CLI_CLIENT_ID from .._internal import AadClient -from .._internal.decorators import log_get_token +from .._internal.decorators import log_get_token, wrap_exceptions +from .._internal.msal_client import MsalClient from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase try: @@ -15,7 +23,8 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any + from typing import Any, Optional + from .. import AuthenticationRecord from .._internal import AadClientBase @@ -37,6 +46,20 @@ class SharedTokenCacheCredential(SharedTokenCacheBase): is unavailable. Defaults to False. """ + def __init__(self, username=None, **kwargs): + # type: (Optional[str], **Any) -> None + + self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] + if self._auth_record: + # authenticate in the tenant that produced the record unless "tenant_id" specifies another + self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id + self._cache = kwargs.pop("_cache", None) + self._app = None + self._client_kwargs = kwargs + self._initialized = False + else: + super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs) + @log_get_token("SharedTokenCacheCredential") def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type (*str, **Any) -> AccessToken @@ -51,8 +74,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user information :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message`` - attribute gives a reason. Any error response from Azure Active Directory is available as the error's - ``response`` attribute. + attribute gives a reason. """ if not scopes: raise ValueError("'get_token' requires at least one scope") @@ -60,9 +82,12 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if not self._initialized: self._initialize() - if not self._client: + if not self._cache: raise CredentialUnavailableError(message="Shared token cache unavailable") + if self._auth_record: + return self._acquire_token_silent(*scopes) + account = self._get_account(self._username, self._tenant_id) token = self._get_cached_access_token(scopes, account) @@ -79,3 +104,54 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument def _get_auth_client(self, **kwargs): # type: (**Any) -> AadClientBase return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs) + + def _initialize(self): + if self._initialized: + return + + if not self._auth_record: + super(SharedTokenCacheCredential, self)._initialize() + return + + self._load_cache() + if self._cache: + self._app = PublicClientApplication( + client_id=self._auth_record.client_id, + authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id), + token_cache=self._cache, + http_client=MsalClient(**self._client_kwargs), + ) + + self._initialized = True + + @wrap_exceptions + def _acquire_token_silent(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + """Silently acquire a token from MSAL. Requires an AuthenticationRecord.""" + + result = None + + accounts_for_user = self._app.get_accounts(username=self._auth_record.username) + if not accounts_for_user: + raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.") + + for account in accounts_for_user: + if account.get("home_account_id") != self._auth_record.home_account_id: + continue + + now = int(time.time()) + result = self._app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs) + if result and "access_token" in result and "expires_in" in result: + return AccessToken(result["access_token"], now + int(result["expires_in"])) + + # if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently + if result: + # cache contains a matching refresh token but STS returned an error response when MSAL tried to use it + message = "Token acquisition failed" + details = result.get("error_description") or result.get("error") + if details: + message += ": {}".format(details) + raise ClientAuthenticationError(message=message) + + # cache doesn't contain a matching refresh (or access) token + raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username)) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index d28a7602fd5e..11d42936cd57 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -30,7 +30,6 @@ # pylint:disable=unused-import,ungrouped-imports from typing import Any, Iterable, List, Mapping, Optional from .._internal import AadClientBase - from azure.identity import AuthenticationRecord CacheItem = Mapping[str, str] @@ -89,34 +88,29 @@ def _filtered_accounts(accounts, username=None, tenant_id=None): class SharedTokenCacheBase(ABC): def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument # type: (Optional[str], **Any) -> None - - self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] - if self._auth_record: - # authenticate in the tenant that produced the record unless 'tenant_id' specifies another - authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id - self._tenant_id = self._auth_record.tenant_id - self._authority = self._auth_record.authority - self._username = self._auth_record.username - self._environment_aliases = frozenset((self._authority,)) - else: - authenticating_tenant = "organizations" - authority = kwargs.pop("authority", None) - self._authority = normalize_authority(authority) if authority else get_default_authority() - environment = urlparse(self._authority).netloc - self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) - self._username = username - self._tenant_id = kwargs.pop("tenant_id", None) - + authority = kwargs.pop("authority", None) + self._authority = normalize_authority(authority) if authority else get_default_authority() + environment = urlparse(self._authority).netloc + self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) + self._username = username + self._tenant_id = kwargs.pop("tenant_id", None) self._cache = kwargs.pop("_cache", None) self._client = None # type: Optional[AadClientBase] self._client_kwargs = kwargs - self._client_kwargs["tenant_id"] = authenticating_tenant + self._client_kwargs["tenant_id"] = "organizations" self._initialized = False def _initialize(self): if self._initialized: return + self._load_cache() + if self._cache: + self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs) + + self._initialized = True + + def _load_cache(self): if not self._cache and self.supported(): allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False) try: @@ -124,11 +118,6 @@ def _initialize(self): except Exception: # pylint:disable=broad-except pass - if self._cache: - self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs) - - self._initialized = True - @abc.abstractmethod def _get_auth_client(self, **kwargs): # type: (**Any) -> AadClientBase @@ -176,14 +165,6 @@ def _get_account(self, username=None, tenant_id=None): # cache is empty or contains no refresh token -> user needs to sign in raise CredentialUnavailableError(message=NO_ACCOUNTS) - if self._auth_record: - for account in accounts: - if account.get("home_account_id") == self._auth_record.home_account_id: - return account - raise CredentialUnavailableError( - message="The cache contains no account matching the given AuthenticationRecord." - ) - filtered_accounts = _filtered_accounts(accounts, username, tenant_id) if len(filtered_accounts) == 1: return filtered_accounts[0] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 08d898b15ae5..10ea133f80ce 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -29,8 +29,6 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager): defines authorities for other clouds. :keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains tokens for multiple identities. - :keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as - :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential` :keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption is unavailable. Defaults to False. """ diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 1bac8aaceedd..e3d6348de4d5 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -170,7 +170,7 @@ def validate_request(request, **_): try: expected_request, response = next(sessions) except StopIteration: - assert False, "unexpected request: {}".format(request) + assert False, "unexpected request: {} {}".format(request.method, request.url) expected_request.assert_matches(request) return response diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 5d756ecface1..631324bea768 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -28,7 +28,15 @@ except ImportError: # python < 3.3 from mock import Mock, patch # type: ignore -from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport +from helpers import ( + build_aad_response, + build_id_token, + get_discovery_response, + mock_response, + msal_validating_transport, + Request, + validating_transport, +) def test_supported(): @@ -513,8 +521,13 @@ def test_authority_environment_variable(): def test_authentication_record_empty_cache(): record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") - transport = Mock(side_effect=Exception("the credential shouldn't send a request")) - credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) + + def send(request, **_): + # expecting only MSAL discovery requests + assert request.method == 'GET' + return get_discovery_response() + + credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=TokenCache()) with pytest.raises(CredentialUnavailableError): credential.get_token("scope") @@ -529,13 +542,17 @@ def test_authentication_record_no_match(): username = "me" record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) - transport = Mock(side_effect=Exception("the credential shouldn't send a request")) + def send(request, **_): + # expecting only MSAL discovery requests + assert request.method == 'GET' + return get_discovery_response() + cache = populated_cache( get_account_event( "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, ), ) - credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) + credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache) with pytest.raises(CredentialUnavailableError): credential.get_token("scope") @@ -557,7 +574,8 @@ def test_authentication_record(): ) cache = populated_cache(account) - transport = validating_transport( + transport = msal_validating_transport( + endpoint="https://{}/{}".format(authority, tenant_id), requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) @@ -593,7 +611,8 @@ def test_auth_record_multiple_accounts_for_username(): ), ) - transport = validating_transport( + transport = msal_validating_transport( + endpoint="https://{}/{}".format(authority, tenant_id), requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], ) @@ -741,19 +760,22 @@ def test_authentication_record_authenticating_tenant(): """when given a record and 'tenant_id', the credential should authenticate in the latter""" expected_tenant_id = "tenant-id" - record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") + record = AuthenticationRecord("not- " + expected_tenant_id, "...", "localhost", "...", "...") - with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: - credential = SharedTokenCacheCredential( - authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id - ) - with pytest.raises(CredentialUnavailableError): - # this raises because the cache is empty - credential.get_token("scope") + def mock_send(request, **_): + if not request.body: + return get_discovery_response() + assert request.url.startswith("https://localhost/" + expected_tenant_id) + return mock_response(json_payload=build_aad_response(access_token="*")) + + transport = Mock(send=Mock(wraps=mock_send)) + credential = SharedTokenCacheCredential( + authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id, transport=transport + ) + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") # this raises because the cache is empty - assert get_auth_client.call_count == 1 - _, kwargs = get_auth_client.call_args - assert kwargs["tenant_id"] == expected_tenant_id + assert transport.send.called def get_account_event( diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 389ba606d482..7613200e97ee 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -7,7 +7,7 @@ from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.identity import AuthenticationRecord, CredentialUnavailableError +from azure.identity import CredentialUnavailableError from azure.identity.aio import SharedTokenCacheCredential from azure.identity._constants import EnvironmentVariables from azure.identity._internal.shared_token_cache import ( @@ -589,122 +589,6 @@ async def test_authority_environment_variable(): assert token.token == expected_access_token -@pytest.mark.asyncio -async def test_authentication_record_empty_cache(): - record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username") - transport = Mock(side_effect=Exception("the credential shouldn't send a request")) - credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache()) - - with pytest.raises(CredentialUnavailableError): - await credential.get_token("scope") - - -@pytest.mark.asyncio -async def test_authentication_record_no_match(): - tenant_id = "tenant-id" - client_id = "client-id" - authority = "localhost" - object_id = "object-id" - home_account_id = object_id + "." + tenant_id - username = "me" - record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) - - transport = Mock(side_effect=Exception("the credential shouldn't send a request")) - cache = populated_cache( - get_account_event( - "not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id, - ), - ) - credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - - with pytest.raises(CredentialUnavailableError): - await credential.get_token("scope") - - -@pytest.mark.asyncio -async def test_authentication_record(): - tenant_id = "tenant-id" - client_id = "client-id" - authority = "localhost" - object_id = "object-id" - home_account_id = object_id + "." + tenant_id - username = "me" - record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) - - expected_access_token = "****" - expected_refresh_token = "**" - account = get_account_event( - username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token - ) - cache = populated_cache(account) - - transport = async_validating_transport( - requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], - responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], - ) - credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - - token = await credential.get_token("scope") - assert token.token == expected_access_token - - -@pytest.mark.asyncio -async def test_auth_record_multiple_accounts_for_username(): - tenant_id = "tenant-id" - client_id = "client-id" - authority = "localhost" - object_id = "object-id" - home_account_id = object_id + "." + tenant_id - username = "me" - record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username) - - expected_access_token = "****" - expected_refresh_token = "**" - expected_account = get_account_event( - username, object_id, tenant_id, authority=authority, client_id=client_id, refresh_token=expected_refresh_token - ) - cache = populated_cache( - expected_account, - get_account_event( # this account matches all but the record's tenant - username, - object_id, - "different-" + tenant_id, - authority=authority, - client_id=client_id, - refresh_token="not-" + expected_refresh_token, - ), - ) - - transport = async_validating_transport( - requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})], - responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))], - ) - credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache) - - token = await credential.get_token("scope") - assert token.token == expected_access_token - - -@pytest.mark.asyncio -async def test_authentication_record_authenticating_tenant(): - """when given a record and 'tenant_id', the credential should authenticate in the latter""" - - expected_tenant_id = "tenant-id" - record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") - - with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: - credential = SharedTokenCacheCredential( - authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id - ) - with pytest.raises(CredentialUnavailableError): - # this raises because the cache is empty - await credential.get_token("scope") - - assert get_auth_client.call_count == 1 - _, kwargs = get_auth_client.call_args - assert kwargs["tenant_id"] == expected_tenant_id - - @pytest.mark.asyncio async def test_allow_unencrypted_cache(): """The credential should use an unencrypted cache when encryption is unavailable and the user explicitly allows it.