Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# Release History

## 1.4.0b4 (Unreleased)
- `CertificateCredential` and `ClientSecretCredential` can optionally store
access tokens they acquire in a persistent cache. To enable this, construct
the credential with `enable_persistent_cache=True`. On Linux, the persistent
cache requires libsecret and `pygobject`. If these are unavailable or
unusable (e.g. in an SSH session), loading the persistent cache will raise an
error. You may optionally configure the credential to fall back to an
unencrypted cache by constructing it with keyword argument
`allow_unencrypted_cache=True`.
([#11347](https://github.com/Azure/azure-sdk-for-python/issues/11347))
- `AzureCliCredential` raises `CredentialUnavailableError` when no user is
logged in to the Azure CLI.
([#11819](https://github.com/Azure/azure-sdk-for-python/issues/11819))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class CertificateCredential(CertificateCredentialBase):
:keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate
requires a different encoding, pass appropriately encoded bytes instead.
:paramtype password: str or bytes
:keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to
False.
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False.
"""

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
Expand All @@ -41,7 +45,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_access_token(scopes)
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class ClientSecretCredential(ClientSecretCredentialBase):
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities`
defines authorities for other clouds.
:keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to
False.
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False.
"""

def get_token(self, *scopes, **kwargs):
Expand All @@ -42,7 +46,7 @@ def get_token(self, *scopes, **kwargs):
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_access_token(scopes)
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)

def get_cached_access_token(self, scopes):
# type: (Sequence[str]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
def get_cached_access_token(self, scopes, query=None):
# type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on - 300 > int(time.time()):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
# ------------------------------------
import abc

from msal import TokenCache
import six
from azure.identity._internal import AadClientCertificate

from . import AadClientCertificate
from .persistent_cache import load_service_principal_cache

try:
ABC = abc.ABC
Expand Down Expand Up @@ -40,7 +43,16 @@ def __init__(self, tenant_id, client_id, certificate_path, **kwargs):
pem_bytes = f.read()

self._certificate = AadClientCertificate(pem_bytes, password=password)
self._client = self._get_auth_client(tenant_id, client_id, **kwargs)

enable_persistent_cache = kwargs.pop("enable_persistent_cache", False)
if enable_persistent_cache:
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
cache = load_service_principal_cache(allow_unencrypted)
else:
cache = TokenCache()

self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs)
self._client_id = client_id

@abc.abstractmethod
def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import abc
from typing import TYPE_CHECKING

from msal import TokenCache

from .persistent_cache import load_service_principal_cache

try:
ABC = abc.ABC
except AttributeError: # Python 2.7
Expand All @@ -27,7 +31,15 @@ def __init__(self, tenant_id, client_id, client_secret, **kwargs):
"tenant_id should be an Azure Active Directory tenant's id (also called its 'directory id')"
)

self._client = self._get_auth_client(tenant_id, client_id, **kwargs)
enable_persistent_cache = kwargs.pop("enable_persistent_cache", False)
if enable_persistent_cache:
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
cache = load_service_principal_cache(allow_unencrypted)
else:
cache = TokenCache()

self._client = self._get_auth_client(tenant_id, client_id, cache=cache, **kwargs)
self._client_id = client_id
self._secret = client_secret

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter
from .persistent_cache import load_persistent_cache
from .persistent_cache import load_user_cache
from .._constants import KnownAuthorities
from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError
from .._internal import get_default_authority, normalize_authority
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(self, client_id, client_credential=None, **kwargs):
if not self._cache:
if kwargs.pop("enable_persistent_cache", False):
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
self._cache = load_persistent_cache(allow_unencrypted)
self._cache = load_user_cache(allow_unencrypted)
else:
self._cache = msal.TokenCache()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@
import msal


def load_persistent_cache(allow_unencrypted):
def load_service_principal_cache(allow_unencrypted):
# type: (Optional[bool]) -> msal.TokenCache
return _load_persistent_cache(allow_unencrypted, "MSALConfidentialCache", "msal.confidential.cache")


def load_user_cache(allow_unencrypted):
# type: (Optional[bool]) -> msal.TokenCache
return _load_persistent_cache(allow_unencrypted, "MSALCache", "msal.cache")


def _load_persistent_cache(allow_unencrypted, account_name, cache_name):
# type: (Optional[bool], str, str) -> msal.TokenCache
"""Load the persistent cache using msal_extensions.

On Windows the cache is a file protected by the Data Protection API. On Linux and macOS the cache is stored by
Expand All @@ -26,19 +36,21 @@ def load_persistent_cache(allow_unencrypted):
"""

if sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ:
cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", "msal.cache")
cache_location = os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", cache_name)
persistence = msal_extensions.FilePersistenceWithDataProtection(cache_location)
elif sys.platform.startswith("darwin"):
# the cache uses this file's modified timestamp to decide whether to reload
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", "msal.cache"))
persistence = msal_extensions.KeychainPersistence(file_path, "Microsoft.Developer.IdentityService", "MSALCache")
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", cache_name))
persistence = msal_extensions.KeychainPersistence(
file_path, "Microsoft.Developer.IdentityService", account_name
)
elif sys.platform.startswith("linux"):
# The cache uses this file's modified timestamp to decide whether to reload. Note this path is the same
# as that of the plaintext fallback: a new encrypted cache will stomp an unencrypted cache.
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", "msal.cache"))
file_path = os.path.expanduser(os.path.join("~", ".IdentityService", cache_name))
try:
persistence = msal_extensions.LibsecretPersistence(
file_path, "msal.cache", {"MsalClientID": "Microsoft.Developer.IdentityService"}, label="MSALCache"
file_path, cache_name, {"MsalClientID": "Microsoft.Developer.IdentityService"}, label=account_name
)
except ImportError:
if not allow_unencrypted:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_access_token(scopes)
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class ClientSecretCredential(AsyncCredentialBase, ClientSecretCredentialBase):
:keyword str authority: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com',
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities`
defines authorities for other clouds.
:keyword bool enable_persistent_cache: if True, the credential will store tokens in a persistent cache. Defaults to
False.
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
is unavailable. Default to False. Has no effect when `enable_persistent_cache` is False.
"""

async def __aenter__(self):
Expand All @@ -48,7 +52,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_access_token(scopes)
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token
Expand Down
101 changes: 100 additions & 1 deletion sdk/identity/azure-identity/tests/test_certificate_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from msal import TokenCache
import pytest
from six.moves.urllib_parse import urlparse

Expand Down Expand Up @@ -135,9 +136,107 @@ def validate_jwt(request, client_id, pem_bytes):
deserialized_header = json.loads(header.decode("utf-8"))
assert deserialized_header["alg"] == "RS256"
assert deserialized_header["typ"] == "JWT"
assert urlsafeb64_decode(deserialized_header["x5t"]) == cert.fingerprint(hashes.SHA1()) #nosec
assert urlsafeb64_decode(deserialized_header["x5t"]) == cert.fingerprint(hashes.SHA1()) # nosec

assert claims["aud"] == request.url
assert claims["iss"] == claims["sub"] == client_id

cert.public_key().verify(signature, signed_part.encode("utf-8"), padding.PKCS1v15(), hashes.SHA256())


@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_enable_persistent_cache(cert_path, cert_password):
"""the credential should use the persistent cache only when given enable_persistent_cache=True"""

persistent_cache = "azure.identity._internal.persistent_cache"
required_arguments = ("tenant-id", "client-id", cert_path)

# credential should default to an in memory cache
raise_when_called = Mock(side_effect=Exception("credential shouldn't attempt to load a persistent cache"))
with patch(persistent_cache + "._load_persistent_cache", raise_when_called):
CertificateCredential(*required_arguments, password=cert_password)

# allowing an unencrypted cache doesn't count as opting in to the persistent cache
CertificateCredential(*required_arguments, password=cert_password, allow_unencrypted_cache=True)

# keyword argument opts in to persistent cache
with patch(persistent_cache + ".msal_extensions") as mock_extensions:
CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True)
assert mock_extensions.PersistedTokenCache.call_count == 1

# opting in on an unsupported platform raises an exception
with patch(persistent_cache + ".sys.platform", "commodore64"):
with pytest.raises(NotImplementedError):
CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True)
with pytest.raises(NotImplementedError):
CertificateCredential(
*required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True
)


@patch("azure.identity._internal.persistent_cache.sys.platform", "linux2")
@patch("azure.identity._internal.persistent_cache.msal_extensions")
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_persistent_cache_linux(mock_extensions, cert_path, cert_password):
"""The credential should use an unencrypted cache when encryption is unavailable and the user explicitly opts in.

This test was written when Linux was the only platform on which encryption may not be available.
"""

required_arguments = ("tenant-id", "client-id", cert_path)

# the credential should prefer an encrypted cache even when the user allows an unencrypted one
CertificateCredential(
*required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True
)
assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.LibsecretPersistence)
mock_extensions.PersistedTokenCache.reset_mock()

# (when LibsecretPersistence's dependencies aren't available, constructing it raises ImportError)
mock_extensions.LibsecretPersistence = Mock(side_effect=ImportError)

# encryption unavailable, no opt in to unencrypted cache -> credential should raise
with pytest.raises(ValueError):
CertificateCredential(*required_arguments, password=cert_password, enable_persistent_cache=True)

CertificateCredential(
*required_arguments, password=cert_password, enable_persistent_cache=True, allow_unencrypted_cache=True
)
assert mock_extensions.PersistedTokenCache.called_with(mock_extensions.FilePersistence)


@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_persistent_cache_multiple_clients(cert_path, cert_password):
"""the credential shouldn't use tokens issued to other service principals"""

access_token_a = "token a"
access_token_b = "not " + access_token_a
transport_a = validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_a))]
)
transport_b = validating_transport(
requests=[Request()], responses=[mock_response(json_payload=build_aad_response(access_token=access_token_b))]
)

cache = TokenCache()
with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
mock_cache_loader.return_value = Mock(wraps=cache)
credential_a = CertificateCredential(
"tenant", "client-a", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_a
)
assert mock_cache_loader.call_count == 1, "credential should load the persistent cache"
credential_b = CertificateCredential(
"tenant", "client-b", cert_path, password=cert_password, enable_persistent_cache=True, transport=transport_b
)
assert mock_cache_loader.call_count == 2, "credential should load the persistent cache"

# A caches a token
scope = "scope"
token_a = credential_a.get_token(scope)
assert token_a.token == access_token_a
assert transport_a.send.call_count == 1

# B should get a different token for the same scope
token_b = credential_b.get_token(scope)
assert token_b.token == access_token_b
assert transport_b.send.call_count == 1
Loading