diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index d753337d91cd..46f0d12cf816 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -1,6 +1,15 @@ # Release History ## 1.4.0b4 (Unreleased) +- `CertificateCredential` and `ClientSecretCredential` can optionally store + access tokens they acquire in a persistent cache. To enable this, construct + the credential with `enable_persistent_cache=True`. On Linux, the persistent + cache requires libsecret and `pygobject`. If these are unavailable or + unusable (e.g. in an SSH session), loading the persistent cache will raise an + error. You may optionally configure the credential to fall back to an + unencrypted cache by constructing it with keyword argument + `allow_unencrypted_cache=True`. + ([#11347](https://github.com/Azure/azure-sdk-for-python/issues/11347)) - `AzureCliCredential` raises `CredentialUnavailableError` when no user is logged in to the Azure CLI. ([#11819](https://github.com/Azure/azure-sdk-for-python/issues/11819)) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index 56d640c12e37..81adb2621a96 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -24,6 +24,10 @@ class CertificateCredential(CertificateCredentialBase): :keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate requires a different encoding, pass appropriately encoded bytes instead. :paramtype password: str or bytes + :keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to + False. + :keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption + is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False. """ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument @@ -41,7 +45,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if not scopes: raise ValueError("'get_token' requires at least one scope") - token = self._client.get_cached_access_token(scopes) + token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py index 35ec0403114a..4e20c2bd900b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py @@ -25,6 +25,10 @@ class ClientSecretCredential(ClientSecretCredentialBase): :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines authorities for other clouds. + :keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to + False. + :keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption + is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False. """ def get_token(self, *scopes, **kwargs): @@ -42,7 +46,7 @@ def get_token(self, *scopes, **kwargs): if not scopes: raise ValueError("'get_token' requires at least one scope") - token = self._client.get_cached_access_token(scopes) + token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 61c697337262..522be7424d7d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -49,9 +49,9 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs): self._client_id = client_id self._pipeline = self._build_pipeline(**kwargs) - def get_cached_access_token(self, scopes): - # type: (Sequence[str]) -> Optional[AccessToken] - tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) + def get_cached_access_token(self, scopes, query=None): + # type: (Sequence[str], Optional[dict]) -> Optional[AccessToken] + tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query) for token in tokens: expires_on = int(token["expires_on"]) if expires_on - 300 > int(time.time()): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py index 1b8736c14019..c13fe86d7a29 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/certificate_credential_base.py @@ -4,8 +4,11 @@ # ------------------------------------ import abc +from msal import TokenCache import six -from azure.identity._internal import AadClientCertificate + +from . import AadClientCertificate +from .persistent_cache import load_service_principal_cache try: ABC = abc.ABC @@ -40,7 +43,16 @@ def __init__(self, tenant_id, client_id, certificate_path, **kwargs): pem_bytes = f.read() self._certificate = AadClientCertificate(pem_bytes, password=password) - self._client = self._get_auth_client(tenant_id, client_id, **kwargs) + + enable_persistent_cache = kwargs.pop("enable_persistent_cache", False) + if enable_persistent_cache: + allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False) + cache = load_service_principal_cache(allow_unencrypted) + else: + cache = TokenCache() + + self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs) + self._client_id = client_id @abc.abstractmethod def _get_auth_client(self, tenant_id, client_id, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_secret_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_secret_credential_base.py index 977c55dab409..4854a396e84f 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_secret_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_secret_credential_base.py @@ -5,6 +5,10 @@ import abc from typing import TYPE_CHECKING +from msal import TokenCache + +from .persistent_cache import load_service_principal_cache + try: ABC = abc.ABC except AttributeError: # Python 2.7 @@ -27,7 +31,15 @@ def __init__(self, tenant_id, client_id, client_secret, **kwargs): "tenant_id should be an Azure Active Directory tenant's id (also called its 'directory id')" ) - self._client = self._get_auth_client(tenant_id, client_id, **kwargs) + enable_persistent_cache = kwargs.pop("enable_persistent_cache", False) + if enable_persistent_cache: + allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False) + cache = load_service_principal_cache(allow_unencrypted) + else: + cache = TokenCache() + + self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs) + self._client_id = client_id self._secret = client_secret @abc.abstractmethod diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 775b11b79368..b408d37d69ac 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -18,7 +18,7 @@ from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter -from .persistent_cache import load_persistent_cache +from .persistent_cache import load_user_cache from .._constants import KnownAuthorities from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError from .._internal import get_default_authority, normalize_authority @@ -98,7 +98,7 @@ def __init__(self, client_id, client_credential=None, **kwargs): if not self._cache: if kwargs.pop("enable_persistent_cache", False): allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False) - self._cache = load_persistent_cache(allow_unencrypted) + self._cache = load_user_cache(allow_unencrypted) else: self._cache = msal.TokenCache() diff --git a/sdk/identity/azure-identity/azure/identity/_internal/persistent_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/persistent_cache.py index d56be5e7fcdb..4887cd296d7a 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/persistent_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/persistent_cache.py @@ -13,8 +13,18 @@ import msal -def load_persistent_cache(allow_unencrypted): +def load_service_principal_cache(allow_unencrypted): # type: (Optional[bool]) -> msal.TokenCache + return _load_persistent_cache(allow_unencrypted, "MSALConfidentialCache", "msal.confidential.cache") + + +def load_user_cache(allow_unencrypted): + # type: (Optional[bool]) -> msal.TokenCache + return _load_persistent_cache(allow_unencrypted, "MSALCache", "msal.cache") + + +def _load_persistent_cache(allow_unencrypted, account_name, cache_name): + # type: (Optional[bool], str, str) -> msal.TokenCache """Load the persistent cache using msal_extensions. On Windows the cache is a file protected by the Data Protection API. On Linux and macOS the cache is stored by @@ -26,19 +36,21 @@ def load_persistent_cache(allow_unencrypted): """ if sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ: - cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", "msal.cache") + cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", cache_name) persistence = msal_extensions.FilePersistenceWithDataProtection(cache_location) elif sys.platform.startswith("darwin"): # the cache uses this file's modified timestamp to decide whether to reload - file_path = os.path.expanduser(os.path.join("~", ".IdentityService", "msal.cache")) - persistence = msal_extensions.KeychainPersistence(file_path, "Microsoft.Developer.IdentityService", "MSALCache") + file_path = os.path.expanduser(os.path.join("~", ".IdentityService", cache_name)) + persistence = msal_extensions.KeychainPersistence( + file_path, "Microsoft.Developer.IdentityService", account_name + ) elif sys.platform.startswith("linux"): # The cache uses this file's modified timestamp to decide whether to reload. Note this path is the same # as that of the plaintext fallback: a new encrypted cache will stomp an unencrypted cache. - file_path = os.path.expanduser(os.path.join("~", ".IdentityService", "msal.cache")) + file_path = os.path.expanduser(os.path.join("~", ".IdentityService", cache_name)) try: persistence = msal_extensions.LibsecretPersistence( - file_path, "msal.cache", {"MsalClientID": "Microsoft.Developer.IdentityService"}, label="MSALCache" + file_path, cache_name, {"MsalClientID": "Microsoft.Developer.IdentityService"}, label=account_name ) except ImportError: if not allow_unencrypted: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index f79a61276ce3..1b044a24c0e1 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -51,7 +51,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py if not scopes: raise ValueError("'get_token' requires at least one scope") - token = self._client.get_cached_access_token(scopes) + token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index 80b5a5577921..87b5472760e6 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -23,6 +23,10 @@ class ClientSecretCredential(AsyncCredentialBase, ClientSecretCredentialBase): :keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines authorities for other clouds. + :keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to + False. + :keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption + is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False. """ async def __aenter__(self): @@ -48,7 +52,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": if not scopes: raise ValueError("'get_token' requires at least one scope") - token = self._client.get_cached_access_token(scopes) + token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id}) if not token: token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) return token diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 7a783b4a727b..af0eee63c580 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -13,6 +13,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding +from msal import TokenCache import pytest from six.moves.urllib_parse import urlparse @@ -135,9 +136,107 @@ def validate_jwt(request, client_id, pem_bytes): deserialized_header = json.loads(header.decode("utf-8")) assert deserialized_header["alg"] == "RS256" assert deserialized_header["typ"] == "JWT" - assert urlsafeb64_decode(deserialized_header["x5t"]) == cert.fingerprint(hashes.SHA1()) #nosec + assert urlsafeb64_decode(deserialized_header["x5t"]) == cert.fingerprint(hashes.SHA1()) # nosec assert claims["aud"] == request.url assert claims["iss"] == claims["sub"] == client_id cert.public_key().verify(signature, signed_part.encode("utf-8"), padding.PKCS1v15(), hashes.SHA256()) + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_enable_persistent_cache(cert_path, cert_password): + """the credential should use the persistent cache only when given enable_persistent_cache=True""" + + persistent_cache = "azure.identity._internal.persistent_cache" + required_arguments = ("tenant-id", "client-id", cert_path) + + # credential should default to an in memory cache + raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) + with patch(persistent_cache + "._load_persistent_cache", raise_when_called): + CertificateCredential(*required_arguments, password=cert_password) + + # allowing an unencrypted cache doesn't count as opting in to the persistent cache + CertificateCredential(*required_arguments, password=cert_password, allow_unencrypted_cache=True) + + # keyword argument opts in to persistent cache + with patch(persistent_cache + ".msal_extensions") as mock_extensions: + CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True) + assert mock_extensions.PersistedTokenCache.call_count == 1 + + # opting in on an unsupported platform raises an exception + with patch(persistent_cache + ".sys.platform", "commodore64"): + with pytest.raises(NotImplementedError): + CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True) + with pytest.raises(NotImplementedError): + CertificateCredential( + *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True + ) + + +@patch("azure.identity._internal.persistent_cache.sys.platform", "linux2") +@patch("azure.identity._internal.persistent_cache.msal_extensions") +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_persistent_cache_linux(mock_extensions, cert_path, cert_password): + """The credential should use an unencrypted cache when encryption is unavailable and the user explicitly opts in. + + This test was written when Linux was the only platform on which encryption may not be available. + """ + + required_arguments = ("tenant-id", "client-id", cert_path) + + # the credential should prefer an encrypted cache even when the user allows an unencrypted one + CertificateCredential( + *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True + ) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.LibsecretPersistence) + mock_extensions.PersistedTokenCache.reset_mock() + + # (when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError) + mock_extensions.LibsecretPersistence = Mock(side_effect=ImportError) + + # encryption unavailable, no opt in to unencrypted cache -> credential should raise + with pytest.raises(ValueError): + CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True) + + CertificateCredential( + *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True + ) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.FilePersistence) + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_persistent_cache_multiple_clients(cert_path, cert_password): + """the credential shouldn't use tokens issued to other service principals""" + + access_token_a = "token a" + access_token_b = "not " + access_token_a + transport_a = validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))] + ) + transport_b = validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))] + ) + + cache = TokenCache() + with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader: + mock_cache_loader.return_value = Mock(wraps=cache) + credential_a = CertificateCredential( + "tenant", "client-a", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_a + ) + assert mock_cache_loader.call_count == 1, "credential should load the persistent cache" + credential_b = CertificateCredential( + "tenant", "client-b", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_b + ) + assert mock_cache_loader.call_count == 2, "credential should load the persistent cache" + + # A caches a token + scope = "scope" + token_a = credential_a.get_token(scope) + assert token_a.token == access_token_a + assert transport_a.send.call_count == 1 + + # B should get a different token for the same scope + token_b = credential_b.get_token(scope) + assert token_b.token == access_token_b + assert transport_b.send.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 948c411bddbe..01d2839fc2cc 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -10,6 +10,7 @@ from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import CertificateCredential +from msal import TokenCache import pytest from helpers import build_aad_response, urlsafeb64_decode, mock_response, Request @@ -133,3 +134,102 @@ async def mock_send(request, **kwargs): token = await cred.get_token("scope") assert token.token == access_token + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_enable_persistent_cache(cert_path, cert_password): + """the credential should use the persistent cache only when given enable_persistent_cache=True""" + + persistent_cache = "azure.identity._internal.persistent_cache" + required_arguments = ("tenant-id", "client-id", cert_path) + + # credential should default to an in memory cache + raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) + with patch(persistent_cache + "._load_persistent_cache", raise_when_called): + CertificateCredential(*required_arguments, password=cert_password) + + # allowing an unencrypted cache doesn't count as opting in to the persistent cache + CertificateCredential(*required_arguments, password=cert_password, allow_unencrypted_cache=True) + + # keyword argument opts in to persistent cache + with patch(persistent_cache + ".msal_extensions") as mock_extensions: + CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True) + assert mock_extensions.PersistedTokenCache.call_count == 1 + + # opting in on an unsupported platform raises an exception + with patch(persistent_cache + ".sys.platform", "commodore64"): + with pytest.raises(NotImplementedError): + CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True) + with pytest.raises(NotImplementedError): + CertificateCredential( + *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True + ) + + +@patch("azure.identity._internal.persistent_cache.sys.platform", "linux2") +@patch("azure.identity._internal.persistent_cache.msal_extensions") +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_persistent_cache_linux(mock_extensions, cert_path, cert_password): + """The credential should use an unencrypted cache when encryption is unavailable and the user explicitly opts in. + + This test was written when Linux was the only platform on which encryption may not be available. + """ + + required_arguments = ("tenant-id", "client-id", cert_path) + + # the credential should prefer an encrypted cache even when the user allows an unencrypted one + CertificateCredential( + *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True + ) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.LibsecretPersistence) + mock_extensions.PersistedTokenCache.reset_mock() + + # (when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError) + mock_extensions.LibsecretPersistence = Mock(side_effect=ImportError) + + # encryption unavailable, no opt in to unencrypted cache -> credential should raise + with pytest.raises(ValueError): + CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True) + + CertificateCredential( + *required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True + ) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.FilePersistence) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +async def test_persistent_cache_multiple_clients(cert_path, cert_password): + """the credential shouldn't use tokens issued to other service principals""" + + access_token_a = "token a" + access_token_b = "not " + access_token_a + transport_a = async_validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))] + ) + transport_b = async_validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))] + ) + + cache = TokenCache() + with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader: + mock_cache_loader.return_value = Mock(wraps=cache) + credential_a = CertificateCredential( + "tenant", "client-a", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_a + ) + assert mock_cache_loader.call_count == 1, "credential should load the persistent cache" + credential_b = CertificateCredential( + "tenant", "client-b", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_b + ) + assert mock_cache_loader.call_count == 2, "credential should load the persistent cache" + + # A caches a token + scope = "scope" + token_a = await credential_a.get_token(scope) + assert token_a.token == access_token_a + assert transport_a.send.call_count == 1 + + # B should get a different token for the same scope + token_b = await credential_b.get_token(scope) + assert token_b.token == access_token_b + assert transport_b.send.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index 732ef67f97c3..ea3362a3f0ff 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -9,6 +9,7 @@ from azure.identity import ClientSecretCredential from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT +from msal import TokenCache import pytest from six.moves.urllib_parse import urlparse @@ -143,3 +144,92 @@ def test_cache(): token = credential.get_token(scope) assert token == valid_token assert mock_send.call_count == 2 + + +def test_enable_persistent_cache(): + """the credential should use the persistent cache only when given enable_persistent_cache=True""" + + required_arguments = ("tenant-id", "client-id", "secret") + persistent_cache = "azure.identity._internal.persistent_cache" + + # credential should default to an in memory cache + raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) + with patch(persistent_cache + "._load_persistent_cache", raise_when_called): + ClientSecretCredential(*required_arguments) + + # allowing an unencrypted cache doesn't count as opting in to the persistent cache + ClientSecretCredential(*required_arguments, allow_unencrypted_cache=True) + + # keyword argument opts in to persistent cache + with patch(persistent_cache + ".msal_extensions") as mock_extensions: + ClientSecretCredential(*required_arguments, enable_persistent_cache=True) + assert mock_extensions.PersistedTokenCache.call_count == 1 + + # opting in on an unsupported platform raises an exception + with patch(persistent_cache + ".sys.platform", "commodore64"): + with pytest.raises(NotImplementedError): + ClientSecretCredential(*required_arguments, enable_persistent_cache=True) + with pytest.raises(NotImplementedError): + ClientSecretCredential(*required_arguments, enable_persistent_cache=True, allow_unencrypted_cache=True) + + +@patch("azure.identity._internal.persistent_cache.sys.platform", "linux2") +@patch("azure.identity._internal.persistent_cache.msal_extensions") +def test_persistent_cache_linux(mock_extensions): + """The credential should use an unencrypted cache when encryption is unavailable and the user explicitly opts in. + + This test was written when Linux was the only platform on which encryption may not be available. + """ + + required_arguments = ("tenant-id", "client-id", "secret") + + # the credential should prefer an encrypted cache even when the user allows an unencrypted one + ClientSecretCredential(*required_arguments, enable_persistent_cache=True, allow_unencrypted_cache=True) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.LibsecretPersistence) + mock_extensions.PersistedTokenCache.reset_mock() + + # (when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError) + mock_extensions.LibsecretPersistence = Mock(side_effect=ImportError) + + # encryption unavailable, no opt in to unencrypted cache -> credential should raise + with pytest.raises(ValueError): + ClientSecretCredential(*required_arguments, enable_persistent_cache=True) + + ClientSecretCredential(*required_arguments, enable_persistent_cache=True, allow_unencrypted_cache=True) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.FilePersistence) + + +def test_persistent_cache_multiple_clients(): + """the credential shouldn't use tokens issued to other service principals""" + + access_token_a = "token a" + access_token_b = "not " + access_token_a + transport_a = validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))] + ) + transport_b = validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))] + ) + + cache = TokenCache() + with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader: + mock_cache_loader.return_value = Mock(wraps=cache) + credential_a = ClientSecretCredential( + "tenant-id", "client-a", "...", enable_persistent_cache=True, transport=transport_a + ) + assert mock_cache_loader.call_count == 1, "credential should load the persistent cache" + credential_b = ClientSecretCredential( + "tenant-id", "client-b", "...", enable_persistent_cache=True, transport=transport_b + ) + assert mock_cache_loader.call_count == 2, "credential should load the persistent cache" + + # A caches a token + scope = "scope" + token_a = credential_a.get_token(scope) + assert token_a.token == access_token_a + assert transport_a.send.call_count == 1 + + # B should get a different token for the same scope + token_b = credential_b.get_token(scope) + assert token_b.token == access_token_b + assert transport_b.send.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index b37ff09dd5bc..4731f1cb7bc2 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import asyncio import time from unittest.mock import Mock, patch from urllib.parse import urlparse @@ -12,11 +11,12 @@ from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import ClientSecretCredential +from msal import TokenCache +import pytest + from helpers import build_aad_response, mock_response, Request from helpers_async import async_validating_transport, AsyncMockTransport, wrap_in_future -import pytest - @pytest.mark.asyncio async def test_no_scopes(): @@ -170,3 +170,93 @@ async def test_cache(): token = await credential.get_token(scope) assert token == valid_token assert mock_send.call_count == 2 + + +def test_enable_persistent_cache(): + """the credential should use the persistent cache only when given enable_persistent_cache=True""" + + required_arguments = ("tenant-id", "client-id", "secret") + persistent_cache = "azure.identity._internal.persistent_cache" + + # credential should default to an in memory cache + raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) + with patch(persistent_cache + "._load_persistent_cache", raise_when_called): + ClientSecretCredential(*required_arguments) + + # allowing an unencrypted cache doesn't count as opting in to the persistent cache + ClientSecretCredential(*required_arguments, allow_unencrypted_cache=True) + + # keyword argument opts in to persistent cache + with patch(persistent_cache + ".msal_extensions") as mock_extensions: + ClientSecretCredential(*required_arguments, enable_persistent_cache=True) + assert mock_extensions.PersistedTokenCache.call_count == 1 + + # opting in on an unsupported platform raises an exception + with patch(persistent_cache + ".sys.platform", "commodore64"): + with pytest.raises(NotImplementedError): + ClientSecretCredential(*required_arguments, enable_persistent_cache=True) + with pytest.raises(NotImplementedError): + ClientSecretCredential(*required_arguments, enable_persistent_cache=True, allow_unencrypted_cache=True) + + +@patch("azure.identity._internal.persistent_cache.sys.platform", "linux2") +@patch("azure.identity._internal.persistent_cache.msal_extensions") +def test_persistent_cache_linux(mock_extensions): + """The credential should use an unencrypted cache when encryption is unavailable and the user explicitly opts in. + + This test was written when Linux was the only platform on which encryption may not be available. + """ + + required_arguments = ("tenant-id", "client-id", "secret") + + # the credential should prefer an encrypted cache even when the user allows an unencrypted one + ClientSecretCredential(*required_arguments, enable_persistent_cache=True, allow_unencrypted_cache=True) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.LibsecretPersistence) + mock_extensions.PersistedTokenCache.reset_mock() + + # (when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError) + mock_extensions.LibsecretPersistence = Mock(side_effect=ImportError) + + # encryption unavailable, no opt in to unencrypted cache -> credential should raise + with pytest.raises(ValueError): + ClientSecretCredential(*required_arguments, enable_persistent_cache=True) + + ClientSecretCredential(*required_arguments, enable_persistent_cache=True, allow_unencrypted_cache=True) + assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.FilePersistence) + + +@pytest.mark.asyncio +async def test_persistent_cache_multiple_clients(): + """the credential shouldn't use tokens issued to other service principals""" + + access_token_a = "token a" + access_token_b = "not " + access_token_a + transport_a = async_validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))] + ) + transport_b = async_validating_transport( + requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))] + ) + + cache = TokenCache() + with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader: + mock_cache_loader.return_value = Mock(wraps=cache) + credential_a = ClientSecretCredential( + "tenant-id", "client-a", "...", enable_persistent_cache=True, transport=transport_a + ) + assert mock_cache_loader.call_count == 1, "credential should load the persistent cache" + credential_b = ClientSecretCredential( + "tenant-id", "client-b", "...", enable_persistent_cache=True, transport=transport_b + ) + assert mock_cache_loader.call_count == 2, "credential should load the persistent cache" + + # A caches a token + scope = "scope" + token_a = await credential_a.get_token(scope) + assert token_a.token == access_token_a + assert transport_a.send.call_count == 1 + + # B should get a different token for the same scope + token_b = await credential_b.get_token(scope) + assert token_b.token == access_token_b + assert transport_b.send.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index 0f3efd57d32e..8bfeaac041a4 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -217,7 +217,7 @@ def _request_token(self, *_, **__): # credential should default to an in memory cache raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache")) - with patch(persistent_cache + ".load_persistent_cache", raise_when_called): + with patch(persistent_cache + "._load_persistent_cache", raise_when_called): with patch(InteractiveCredential.__module__ + ".msal.TokenCache", lambda: in_memory_cache): credential = TestCredential() assert credential._cache is in_memory_cache