Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
pass

def supports_caching(self):
# type: () -> bool
"""Whether this TokenCredential maintains its own token cache.

An authentication policy may call this before deciding whether to establish its own cache.
"""


else:
AccessToken = namedtuple("AccessToken", ["token", "expires_on"])
Expand Down
7 changes: 7 additions & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ class AsyncTokenCredential(Protocol):
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
pass

def supports_caching(self):
# type: () -> bool
"""Whether this TokenCredential maintains its own token cache.

An authentication policy may call this before deciding whether to establish its own cache.
"""

async def close(self) -> None:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argu
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._credential_supports_caching = getattr(self._credential, "supports_caching", lambda: False)()
self._token = None # type: Optional[AccessToken]

@staticmethod
Expand Down Expand Up @@ -68,7 +69,7 @@ def _update_headers(headers, token):
@property
def _need_new_token(self):
# type: () -> bool
return not self._token or self._token.expires_on - time.time() < 300
return self._credential_supports_caching or (not self._token or self._token.expires_on - time.time() < 300)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "
# pylint:disable=unused-argument
super().__init__()
self._credential = credential
self._credential_supports_caching = getattr(self._credential, "supports_caching", lambda: False)()
self._lock = asyncio.Lock()
self._scopes = scopes
self._token = None # type: Optional[AccessToken]
Expand Down Expand Up @@ -129,4 +130,5 @@ def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[boo
return False

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
return self._credential_supports_caching or (not self._token or self._token.expires_on - time.time() < 300)

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def get_token(_):
get_token_calls += 1
return expected_token

fake_credential = Mock(get_token=get_token)
fake_credential = Mock(get_token=get_token, supports_caching=lambda: False)
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]
pipeline = AsyncPipeline(transport=Mock(), policies=policies)

Expand Down Expand Up @@ -71,7 +71,7 @@ async def get_token(_):
get_token_calls += 1
return expected_token

credential = Mock(get_token=get_token)
credential = Mock(get_token=get_token, supports_caching=lambda: False)
policies = [
AsyncBearerTokenCredentialPolicy(credential, "scope"),
Mock(send=Mock(return_value=get_completed_future(Mock()))),
Expand Down Expand Up @@ -100,6 +100,22 @@ async def get_token(_):
assert get_token_calls == 2 # token expired -> policy should call get_token


async def test_bearer_policy_credential_supports_caching():
"""BearerTokenCredentialPolicy should not cache tokens when its credential claims to do so"""

token = AccessToken("token", int(time.time()) + 3600)
credential = Mock(get_token=Mock(return_value=get_completed_future(token)), supports_caching=lambda: True)
pipeline = AsyncPipeline(
transport=Mock(send=lambda _: get_completed_future(Mock())),
policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")],
)

request_count = 6
for _ in range(request_count):
await pipeline.run(HttpRequest("GET", "https://localhost"))
assert credential.get_token.call_count == request_count


async def test_bearer_policy_optionally_enforces_https():
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""

Expand Down Expand Up @@ -202,6 +218,7 @@ async def fake_send(*args, **kwargs):
fake_send.calls = 1
return Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'})
raise TestException()

fake_send.calls = 0

policy = TestPolicy(credential, "scope")
Expand Down
19 changes: 16 additions & 3 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token)
return Mock()

fake_credential = Mock(get_token=Mock(return_value=expected_token))
fake_credential = Mock(get_token=Mock(return_value=expected_token), supports_caching=lambda: False)
policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]

pipeline = Pipeline(transport=Mock(), policies=policies)
Expand Down Expand Up @@ -67,7 +67,7 @@ def verify_request(request):

def test_bearer_policy_token_caching():
good_for_one_hour = AccessToken("token", time.time() + 3600)
credential = Mock(get_token=Mock(return_value=good_for_one_hour))
credential = Mock(get_token=Mock(return_value=good_for_one_hour), supports_caching=lambda: False)
pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")])

pipeline.run(HttpRequest("GET", "https://spam.eggs"))
Expand All @@ -88,6 +88,19 @@ def test_bearer_policy_token_caching():
assert credential.get_token.call_count == 2 # token expired -> policy should call get_token


def test_bearer_policy_credential_supports_caching():
"""BearerTokenCredentialPolicy should not cache tokens when its credential claims to do so"""

token = AccessToken("token", int(time.time()) + 3600)
credential = Mock(get_token=Mock(return_value=token), supports_caching=lambda: True)
pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")])

request_count = 6
for _ in range(request_count):
pipeline.run(HttpRequest("GET", "https://localhost"))
assert credential.get_token.call_count == request_count


def test_bearer_policy_optionally_enforces_https():
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""

Expand Down Expand Up @@ -258,7 +271,7 @@ def test_key_vault_regression():

from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase

credential = Mock()
credential = Mock(supports_caching=lambda: False)
policy = _BearerTokenCredentialPolicyBase(credential)
assert policy._credential is credential

Expand Down