diff --git a/sdk/core/azure-core/azure/core/credentials.py b/sdk/core/azure-core/azure/core/credentials.py index e5146d7f947a..f14c1eef4acb 100644 --- a/sdk/core/azure-core/azure/core/credentials.py +++ b/sdk/core/azure-core/azure/core/credentials.py @@ -24,6 +24,13 @@ def get_token(self, *scopes, **kwargs): # type: (*str, **Any) -> AccessToken pass + def supports_caching(self): + # type: () -> bool + """Whether this TokenCredential maintains its own token cache. + + An authentication policy may call this before deciding whether to establish its own cache. + """ + else: AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) diff --git a/sdk/core/azure-core/azure/core/credentials_async.py b/sdk/core/azure-core/azure/core/credentials_async.py index b42e717f4790..f16802b09e1f 100644 --- a/sdk/core/azure-core/azure/core/credentials_async.py +++ b/sdk/core/azure-core/azure/core/credentials_async.py @@ -13,6 +13,13 @@ class AsyncTokenCredential(Protocol): async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: pass + def supports_caching(self): + # type: () -> bool + """Whether this TokenCredential maintains its own token cache. + + An authentication policy may call this before deciding whether to establish its own cache. + """ + async def close(self) -> None: pass diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 228e3fd20f58..6e096f55da67 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -35,6 +35,7 @@ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argu super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes self._credential = credential + self._credential_supports_caching = getattr(self._credential, "supports_caching", lambda: False)() self._token = None # type: Optional[AccessToken] @staticmethod @@ -68,7 +69,7 @@ def _update_headers(headers, token): @property def _need_new_token(self): # type: () -> bool - return not self._token or self._token.expires_on - time.time() < 300 + return self._credential_supports_caching or (not self._token or self._token.expires_on - time.time() < 300) class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy): diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 76564320b742..111673429edb 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -31,6 +31,7 @@ def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: " # pylint:disable=unused-argument super().__init__() self._credential = credential + self._credential_supports_caching = getattr(self._credential, "supports_caching", lambda: False)() self._lock = asyncio.Lock() self._scopes = scopes self._token = None # type: Optional[AccessToken] @@ -129,4 +130,5 @@ def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[boo return False def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + return self._credential_supports_caching or (not self._token or self._token.expires_on - time.time() < 300) + diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 7230018aa37f..ca1f2356e9cc 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -33,7 +33,7 @@ async def get_token(_): get_token_calls += 1 return expected_token - fake_credential = Mock(get_token=get_token) + fake_credential = Mock(get_token=get_token, supports_caching=lambda: False) policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = AsyncPipeline(transport=Mock(), policies=policies) @@ -71,7 +71,7 @@ async def get_token(_): get_token_calls += 1 return expected_token - credential = Mock(get_token=get_token) + credential = Mock(get_token=get_token, supports_caching=lambda: False) policies = [ AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=Mock(return_value=get_completed_future(Mock()))), @@ -100,6 +100,22 @@ async def get_token(_): assert get_token_calls == 2 # token expired -> policy should call get_token +async def test_bearer_policy_credential_supports_caching(): + """BearerTokenCredentialPolicy should not cache tokens when its credential claims to do so""" + + token = AccessToken("token", int(time.time()) + 3600) + credential = Mock(get_token=Mock(return_value=get_completed_future(token)), supports_caching=lambda: True) + pipeline = AsyncPipeline( + transport=Mock(send=lambda _: get_completed_future(Mock())), + policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")], + ) + + request_count = 6 + for _ in range(request_count): + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert credential.get_token.call_count == request_count + + async def test_bearer_policy_optionally_enforces_https(): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" @@ -202,6 +218,7 @@ async def fake_send(*args, **kwargs): fake_send.calls = 1 return Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) raise TestException() + fake_send.calls = 0 policy = TestPolicy(credential, "scope") diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index de029e8ea352..6b9b2431f7d7 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -35,7 +35,7 @@ def verify_authorization_header(request): assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) return Mock() - fake_credential = Mock(get_token=Mock(return_value=expected_token)) + fake_credential = Mock(get_token=Mock(return_value=expected_token), supports_caching=lambda: False) policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = Pipeline(transport=Mock(), policies=policies) @@ -67,7 +67,7 @@ def verify_request(request): def test_bearer_policy_token_caching(): good_for_one_hour = AccessToken("token", time.time() + 3600) - credential = Mock(get_token=Mock(return_value=good_for_one_hour)) + credential = Mock(get_token=Mock(return_value=good_for_one_hour), supports_caching=lambda: False) pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) pipeline.run(HttpRequest("GET", "https://spam.eggs")) @@ -88,6 +88,19 @@ def test_bearer_policy_token_caching(): assert credential.get_token.call_count == 2 # token expired -> policy should call get_token +def test_bearer_policy_credential_supports_caching(): + """BearerTokenCredentialPolicy should not cache tokens when its credential claims to do so""" + + token = AccessToken("token", int(time.time()) + 3600) + credential = Mock(get_token=Mock(return_value=token), supports_caching=lambda: True) + pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) + + request_count = 6 + for _ in range(request_count): + pipeline.run(HttpRequest("GET", "https://localhost")) + assert credential.get_token.call_count == request_count + + def test_bearer_policy_optionally_enforces_https(): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" @@ -258,7 +271,7 @@ def test_key_vault_regression(): from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase - credential = Mock() + credential = Mock(supports_caching=lambda: False) policy = _BearerTokenCredentialPolicyBase(credential) assert policy._credential is credential