From 992ac795ac2cf7a17a8464cb34e3309391c10c96 Mon Sep 17 00:00:00 2001 From: msyyc <70930885+msyyc@users.noreply.github.com> Date: Mon, 21 Jun 2021 16:41:34 +0800 Subject: [PATCH] cae --- sdk/core/azure-core/CHANGELOG.md | 3 + .../azure/core/pipeline/policies/__init__.py | 11 +- .../core/pipeline/policies/_authentication.py | 80 ++++++ .../policies/_authentication_async.py | 104 ++++++++ .../test_challenge_authentication_async.py | 231 ++++++++++++++++++ .../tests/test_challenge_authentication.py | 194 +++++++++++++++ 6 files changed, 621 insertions(+), 2 deletions(-) create mode 100644 sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py create mode 100644 sdk/core/azure-core/tests/test_challenge_authentication.py diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 10e242697b12..9db2aac10217 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -2,6 +2,9 @@ ## 1.15.1 (Unreleased) +### New Features + +- Added `azure.core.pipeline.policies.ChallengeAuthenticationPolicy` and `.AsyncChallengeAuthenticationPolicy` ## 1.15.0 (2021-06-04) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py b/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py index a0e81b13cef5..0f35696bdbf8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py @@ -25,7 +25,12 @@ # -------------------------------------------------------------------------- from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory -from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy +from ._authentication import ( + BearerTokenCredentialPolicy, + AzureKeyCredentialPolicy, + AzureSasCredentialPolicy, + ChallengeAuthenticationPolicy, +) from ._custom_hook import CustomHookPolicy from ._redirect import RedirectPolicy from ._retry import RetryPolicy, RetryMode @@ -44,6 +49,7 @@ 'HTTPPolicy', 'SansIOHTTPPolicy', 'BearerTokenCredentialPolicy', + 'ChallengeAuthenticationPolicy', 'AzureKeyCredentialPolicy', 'AzureSasCredentialPolicy', 'HeadersPolicy', @@ -65,12 +71,13 @@ try: from ._base_async import AsyncHTTPPolicy - from ._authentication_async import AsyncBearerTokenCredentialPolicy + from ._authentication_async import AsyncBearerTokenCredentialPolicy, AsyncChallengeAuthenticationPolicy from ._redirect_async import AsyncRedirectPolicy from ._retry_async import AsyncRetryPolicy __all__.extend([ 'AsyncHTTPPolicy', 'AsyncBearerTokenCredentialPolicy', + 'AsyncChallengeAuthenticationPolicy', 'AsyncRedirectPolicy', 'AsyncRetryPolicy' ]) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 251dc8a610bf..88c346a8df19 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -40,6 +40,8 @@ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argu @staticmethod def _enforce_https(request): # type: (PipelineRequest) -> None + """Raise ServiceRequestError if the request URL is non-HTTPS and the sender did not specify "enforce_https=False" + """ # move 'enforce_https' from options to context so it persists # across retries but isn't passed to a transport implementation @@ -171,6 +173,84 @@ def on_exception(self, request): return False +class ChallengeAuthenticationPolicy(HTTPPolicy): + """Base class for policies that authorize requests with bearer tokens and expect authentication challenges + + :param ~azure.core.credentials.TokenCredential credential: an object which can provide access tokens, such as a + credential from :mod:`azure.identity` + :param str scopes: required authentication scopes + """ + + def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (TokenCredential, *str, **Any) -> None + super(ChallengeAuthenticationPolicy, self).__init__() + self._scopes = scopes + self._credential = credential + self._token = None # type: Optional[AccessToken] + + def _need_new_token(self): + # type: () -> bool + return not self._token or self._token.expires_on - time.time() < 300 + + def authorize_request(self, request, *scopes, **kwargs): + # type: (PipelineRequest, *str, **Any) -> None + """Acquire a token from the credential and authorize the request with it. + + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to + authorize future requests. + + :param ~azure.core.pipeline.PipelineRequest request: the request + :param str scopes: required scopes of authentication + """ + self._token = self._credential.get_token(*scopes, **kwargs) + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + def on_request(self, request): + # type: (PipelineRequest) -> None + """Called before the policy sends a request. + + The base implementation authorizes the request with a bearer token. + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + + if self._token is None or self._need_new_token(): + self._token = self._credential.get_token(*self._scopes) + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + def send(self, request): + # type: (PipelineRequest) -> PipelineResponse + """Authorizes a request with a bearer token, possibly handling an authentication challenge + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + _BearerTokenCredentialPolicyBase._enforce_https(request) + + self.on_request(request) + + response = self.next.send(request) + + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers and self.on_challenge(request, response): + response = self.next.send(request) + + return response + + def on_challenge(self, request, response): + # type: (PipelineRequest, PipelineResponse) -> bool + """Authorize request according to an authentication challenge + + This method is called when the resource provider responds 401 with a WWW-Authenticate header. + + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response + :returns: a bool indicating whether the policy should send the request + """ + # pylint:disable=unused-argument,no-self-use + return False + + class AzureKeyCredentialPolicy(SansIOHTTPPolicy): """Adds a key header for the provided credential. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 479ef9057571..e4e6ec1fa49a 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -125,3 +125,107 @@ def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[boo def _need_new_token(self) -> bool: return not self._token or self._token.expires_on - time.time() < 300 + + +class AsyncChallengeAuthenticationPolicy(AsyncHTTPPolicy): + """Base class for policies that authorize requests with bearer tokens and expect authentication challenges + + :param ~azure.core.credentials.AsyncTokenCredential credential: an object which can asynchronously provide access + tokens, such as a credential from :mod:`azure.identity.aio` + :param str scopes: required authentication scopes + """ + + def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None: + # pylint:disable=unused-argument + super().__init__() + self._credential = credential + self._lock = asyncio.Lock() + self._scopes = scopes + self._token = None # type: Optional[AccessToken] + + def _need_new_token(self) -> bool: + return not self._token or self._token.expires_on - time.time() < 300 + + async def on_request(self, request: "PipelineRequest") -> None: + """Called before the policy sends a request. + + The base implementation authorizes the request with a bearer token. + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + + if self._token is None or self._need_new_token(): + async with self._lock: + # double check because another coroutine may have acquired a token while we waited to acquire the lock + if not self._token or self._need_new_token(): + self._token = await self._credential.get_token(*self._scopes) + + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None: + """Acquire a token from the credential and authorize the request with it. + + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to + authorize future requests. + + :param ~azure.core.pipeline.PipelineRequest request: the request + :param str scopes: required scopes of authentication + """ + + async with self._lock: + self._token = await self._credential.get_token(*scopes, **kwargs) + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + """Authorizes a request with a bearer token, possibly handling an authentication challenge + + :param ~azure.core.pipeline.PipelineRequest request: The request + """ + _BearerTokenCredentialPolicyBase._enforce_https(request) + + await self.on_request(request) + + response = await self.next.send(request) + + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + response = await self.next.send(request) + + return response + + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + """Authorize request according to an authentication challenge + + This method is called when the resource provider responds 401 with a WWW-Authenticate header. + + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response + :returns: a bool indicating whether the policy should send the request + """ + # pylint:disable=unused-argument,no-self-use + return False + + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> "Union[None, Awaitable[None]]": + """Executed after the request comes back from the next policy. + + :param request: Request to be modified after returning from the policy. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: Pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + """ + + def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[bool]]": + """Executed when an exception is raised while executing the next policy. + + This method is executed inside the exception handler. + + :param request: The Pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: False by default, override with True to stop the exception. + :rtype: bool + """ + # pylint: disable=no-self-use,unused-argument + return False \ No newline at end of file diff --git a/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py new file mode 100644 index 000000000000..827476dc868b --- /dev/null +++ b/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py @@ -0,0 +1,231 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import asyncio +import time +from unittest.mock import Mock + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ServiceRequestError +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies import AsyncChallengeAuthenticationPolicy, SansIOHTTPPolicy +from azure.core.pipeline.transport import HttpRequest + +import pytest + +pytestmark = pytest.mark.asyncio + + +async def test_adds_header(): + """The bearer token policy should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessToken("expected_token", 2524608000) + + async def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) + return Mock() + + get_token_calls = 0 + + async def get_token(_): + nonlocal get_token_calls + get_token_calls += 1 + return expected_token + + fake_credential = Mock(get_token=get_token) + policies = [AsyncChallengeAuthenticationPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + pipeline = AsyncPipeline(transport=Mock(), policies=policies) + + await pipeline.run(HttpRequest("GET", "https://localhost"), context=None) + assert get_token_calls == 1 + + await pipeline.run(HttpRequest("GET", "https://localhost"), context=None) + # Didn't need a new token + assert get_token_calls == 1 + + +async def test_default_context(): + """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" + + async def send(_): + return Mock() + + async def get_token(_): + return AccessToken("", 0) + + expected_scope = "scope" + credential = Mock(get_token=Mock(wraps=get_token)) + policy = AsyncChallengeAuthenticationPolicy(credential, expected_scope) + pipeline = AsyncPipeline(transport=Mock(send=send), policies=[policy]) + + await pipeline.run(HttpRequest("GET", "https://localhost")) + + credential.get_token.assert_called_once_with(expected_scope) + + +async def test_send(): + """The bearer token policy should invoke the next policy's send method and return the result""" + expected_request = HttpRequest("GET", "https://localhost") + expected_response = Mock() + + async def verify_request(request): + assert request.http_request is expected_request + return expected_response + + fake_credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("", 0))) + policies = [AsyncChallengeAuthenticationPolicy(fake_credential, "scope"), Mock(send=verify_request)] + response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request) + + assert response is expected_response + + +async def test_token_caching(): + good_for_one_hour = AccessToken("token", time.time() + 3600) + expected_token = good_for_one_hour + get_token_calls = 0 + + async def get_token(_): + nonlocal get_token_calls + get_token_calls += 1 + return expected_token + + credential = Mock(get_token=get_token) + + async def send(_): + return Mock() + + transport = Mock(send=send) + + pipeline = AsyncPipeline(transport=transport, policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")]) + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 1 # policy has no token at first request -> it should call get_token + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 1 # token is good for an hour -> policy should return it from cache + + expired_token = AccessToken("token", time.time()) + get_token_calls = 0 + expected_token = expired_token + pipeline = AsyncPipeline(transport=transport, policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")]) + + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 1 + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 2 # token expired -> policy should call get_token + + +async def test_optionally_enforces_https(): + """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" + + async def assert_option_popped(request, **kwargs): + assert "enforce_https" not in kwargs, "AsyncChallengeAuthenticationPolicy didn't pop the 'enforce_https' option" + return Mock() + + credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42))) + pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")]) + + # by default and when enforce_https=True, the policy should raise when given an insecure request + with pytest.raises(ServiceRequestError): + await pipeline.run(HttpRequest("GET", "http://localhost")) + with pytest.raises(ServiceRequestError): + await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=True) + + # when enforce_https=False, an insecure request should pass + await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) + + # https requests should always pass + await pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=False) + await pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=True) + await pipeline.run(HttpRequest("GET", "https://localhost")) + + +async def test_preserves_enforce_https_opt_out(): + """The policy should use request context to preserve an opt out from https enforcement""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert "enforce_https" in request.context, "'enforce_https' is not in the request's context" + + async def send(_): + return Mock() + + transport = Mock(send=send) + + get_token = get_completed_future(AccessToken("***", 42)) + credential = Mock(get_token=lambda *_, **__: get_token) + policies = [AsyncChallengeAuthenticationPolicy(credential, "scope"), ContextValidator()] + pipeline = AsyncPipeline(transport=transport, policies=policies) + + await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) + + +async def test_context_unmodified_by_default(): + """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert not any(request.context), "the policy shouldn't add to the request's context" + + async def send(_): + return Mock() + + transport = Mock(send=send) + + get_token = get_completed_future(AccessToken("***", 42)) + credential = Mock(get_token=lambda *_, **__: get_token) + policies = [AsyncChallengeAuthenticationPolicy(credential, "scope"), ContextValidator()] + pipeline = AsyncPipeline(transport=transport, policies=policies) + + await pipeline.run(HttpRequest("GET", "https://localhost")) + + +async def test_cannot_complete_challenge(): + """ChallengeAuthenticationPolicy should return the 401 response when it can't complete a challenge""" + + expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) + + async def send(_): + return expected_response + + transport = Mock(send=Mock(wraps=send)) + + expected_scope = "scope" + get_token = Mock(return_value=get_completed_future(AccessToken("***", 42))) + credential = Mock(get_token=get_token) + policy = AsyncChallengeAuthenticationPolicy(credential, expected_scope) + policy.on_challenge = Mock(wraps=policy.on_challenge) + + pipeline = AsyncPipeline(transport=transport, policies=[policy]) + response = await pipeline.run(HttpRequest("GET", "https://localhost")) + + assert policy.on_challenge.called + assert response.http_response is expected_response + assert transport.send.call_count == 1 + credential.get_token.assert_called_once_with(expected_scope) + + +def get_completed_future(result=None): + fut = asyncio.Future() + fut.set_result(result) + return fut diff --git a/sdk/core/azure-core/tests/test_challenge_authentication.py b/sdk/core/azure-core/tests/test_challenge_authentication.py new file mode 100644 index 000000000000..f30d439a64ea --- /dev/null +++ b/sdk/core/azure-core/tests/test_challenge_authentication.py @@ -0,0 +1,194 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import time + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ServiceRequestError +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ChallengeAuthenticationPolicy, SansIOHTTPPolicy +from azure.core.pipeline.transport import HttpRequest + +import pytest + +try: + from unittest.mock import Mock +except ImportError: + # python < 3.3 + from mock import Mock + + +def test_adds_header(): + """The policy should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessToken("expected_token", 2524608000) + + def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) + return Mock() + + fake_credential = Mock(get_token=Mock(return_value=expected_token)) + policy = ChallengeAuthenticationPolicy(fake_credential, "scope") + policies = [policy, Mock(send=verify_authorization_header)] + + pipeline = Pipeline(transport=Mock(), policies=policies) + pipeline.run(HttpRequest("GET", "https://localhost")) + + assert fake_credential.get_token.call_count == 1 + + pipeline.run(HttpRequest("GET", "https://localhost")) + + # Didn't need a new token + assert fake_credential.get_token.call_count == 1 + + +def test_default_context(): + """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" + expected_scope = "scope" + token = AccessToken("", 0) + credential = Mock(get_token=Mock(return_value=token)) + policy = ChallengeAuthenticationPolicy(credential, expected_scope) + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + pipeline.run(HttpRequest("GET", "https://localhost")) + + credential.get_token.assert_called_once_with(expected_scope) + + +def test_send(): + """The policy should invoke the next policy's send method and return the result""" + expected_request = HttpRequest("GET", "https://localhost") + expected_response = Mock() + + def verify_request(request): + assert request.http_request is expected_request + return expected_response + + fake_credential = Mock(get_token=lambda _: AccessToken("", 0)) + policy = ChallengeAuthenticationPolicy(fake_credential, "scope") + policies = [policy, Mock(send=verify_request)] + response = Pipeline(transport=Mock(), policies=policies).run(expected_request) + + assert response is expected_response + + +def test_token_caching(): + good_for_one_hour = AccessToken("token", time.time() + 3600) + credential = Mock(get_token=Mock(return_value=good_for_one_hour)) + policy = ChallengeAuthenticationPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + pipeline.run(HttpRequest("GET", "https://localhost")) + assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token + + pipeline.run(HttpRequest("GET", "https://localhost")) + assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache + + expired_token = AccessToken("token", time.time()) + credential.get_token.reset_mock() + credential.get_token.return_value = expired_token + pipeline = Pipeline(transport=Mock(), policies=[ChallengeAuthenticationPolicy(credential, "scope")]) + + pipeline.run(HttpRequest("GET", "https://localhost")) + assert credential.get_token.call_count == 1 + + pipeline.run(HttpRequest("GET", "https://localhost")) + assert credential.get_token.call_count == 2 # token expired -> policy should call get_token + + +def test_optionally_enforces_https(): + """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" + + def assert_option_popped(request, **kwargs): + assert "enforce_https" not in kwargs, "ChallengeAuthenticationPolicy didn't pop the 'enforce_https' option" + return Mock() + + credential = Mock(get_token=lambda *_, **__: AccessToken("***", 42)) + policy = ChallengeAuthenticationPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(send=assert_option_popped), policies=[policy]) + + # by default and when enforce_https=True, the policy should raise when given an insecure request + with pytest.raises(ServiceRequestError): + pipeline.run(HttpRequest("GET", "http://localhost")) + with pytest.raises(ServiceRequestError): + pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=True) + + # when enforce_https=False, an insecure request should pass + pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) + + # https requests should always pass + pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=False) + pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=True) + pipeline.run(HttpRequest("GET", "https://localhost")) + + +def test_preserves_enforce_https_opt_out(): + """The policy should use request context to preserve an opt out from https enforcement""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert "enforce_https" in request.context, "'enforce_https' is not in the request's context" + return Mock() + + credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) + policy = ChallengeAuthenticationPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) + + +def test_context_unmodified_by_default(): + """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert not any(request.context), "the policy shouldn't add to the request's context" + + credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) + policy = ChallengeAuthenticationPolicy(credential, "scope") + policies = [policy, ContextValidator()] + pipeline = Pipeline(transport=Mock(), policies=policies) + + pipeline.run(HttpRequest("GET", "https://localhost")) + + +def test_cannot_complete_challenge(): + """ChallengeAuthenticationPolicy should return the 401 response when it can't complete a challenge""" + + expected_scope = "scope" + expected_token = AccessToken("***", int(time.time()) + 3600) + credential = Mock(get_token=Mock(return_value=expected_token)) + expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) + transport = Mock(send=Mock(return_value=expected_response)) + policy = ChallengeAuthenticationPolicy(credential, expected_scope) + policy.on_challenge = Mock(wraps=policy.on_challenge) + + pipeline = Pipeline(transport=transport, policies=[policy]) + response = pipeline.run(HttpRequest("GET", "https://localhost")) + + assert policy.on_challenge.called + assert response.http_response is expected_response + assert transport.send.call_count == 1 + credential.get_token.assert_called_once_with(expected_scope)