From 75406a093a7426b862ca77325d40cc5f0cd4f8c2 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 8 Mar 2021 13:51:53 -0800 Subject: [PATCH 1/5] Async/ChallengeAuthenticationPolicy (azure-core) --- .../azure/core/pipeline/policies/__init__.py | 11 +- .../core/pipeline/policies/_authentication.py | 119 +++++++++++++++--- .../policies/_authentication_async.py | 99 ++++++++++++++- 3 files changed, 207 insertions(+), 22 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py b/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py index a0e81b13cef5..0f35696bdbf8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py @@ -25,7 +25,12 @@ # -------------------------------------------------------------------------- from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory -from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy +from ._authentication import ( + BearerTokenCredentialPolicy, + AzureKeyCredentialPolicy, + AzureSasCredentialPolicy, + ChallengeAuthenticationPolicy, +) from ._custom_hook import CustomHookPolicy from ._redirect import RedirectPolicy from ._retry import RetryPolicy, RetryMode @@ -44,6 +49,7 @@ 'HTTPPolicy', 'SansIOHTTPPolicy', 'BearerTokenCredentialPolicy', + 'ChallengeAuthenticationPolicy', 'AzureKeyCredentialPolicy', 'AzureSasCredentialPolicy', 'HeadersPolicy', @@ -65,12 +71,13 @@ try: from ._base_async import AsyncHTTPPolicy - from ._authentication_async import AsyncBearerTokenCredentialPolicy + from ._authentication_async import AsyncBearerTokenCredentialPolicy, AsyncChallengeAuthenticationPolicy from ._redirect_async import AsyncRedirectPolicy from ._retry_async import AsyncRetryPolicy __all__.extend([ 'AsyncHTTPPolicy', 'AsyncBearerTokenCredentialPolicy', + 'AsyncChallengeAuthenticationPolicy', 'AsyncRedirectPolicy', 'AsyncRetryPolicy' ]) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 929920033cdf..9915050c67b5 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -6,7 +6,7 @@ import time import six -from . import SansIOHTTPPolicy +from . import HTTPPolicy, SansIOHTTPPolicy from ...exceptions import ServiceRequestError try: @@ -18,7 +18,27 @@ # pylint:disable=unused-import from typing import Any, Dict, Optional from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential, AzureSasCredential - from azure.core.pipeline import PipelineRequest + from azure.core.pipeline import PipelineRequest, PipelineResponse + + +def _enforce_https(request): + # type: (PipelineRequest) -> None + """Raise ServiceRequestError if the request URL is non-HTTPS and the sender did not specify "enforce_https=False" + """ + + # move 'enforce_https' from options to context so it persists + # across retries but isn't passed to a transport implementation + option = request.context.options.pop("enforce_https", None) + + # True is the default setting; we needn't preserve an explicit opt in to the default behavior + if option is False: + request.context["enforce_https"] = option + + enforce_https = request.context.get("enforce_https", True) + if enforce_https and not request.http_request.url.lower().startswith("https"): + raise ServiceRequestError( + "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs." + ) # pylint:disable=too-few-public-methods @@ -40,20 +60,7 @@ def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argu @staticmethod def _enforce_https(request): # type: (PipelineRequest) -> None - - # move 'enforce_https' from options to context so it persists - # across retries but isn't passed to a transport implementation - option = request.context.options.pop("enforce_https", None) - - # True is the default setting; we needn't preserve an explicit opt in to the default behavior - if option is False: - request.context["enforce_https"] = option - - enforce_https = request.context.get("enforce_https", True) - if enforce_https and not request.http_request.url.lower().startswith("https"): - raise ServiceRequestError( - "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs." - ) + return _enforce_https(request) @staticmethod def _update_headers(headers, token): @@ -94,6 +101,86 @@ def on_request(self, request): self._update_headers(request.http_request.headers, self._token.token) +class ChallengeAuthenticationPolicy(HTTPPolicy): + """Base class for policies that authorize requests with bearer tokens and expect authentication challenges + + :param ~azure.core.credentials.TokenCredential credential: an object which can provide access tokens, such as a + credential from :mod:`azure.identity` + :param str scopes: required authentication scopes + """ + + def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (TokenCredential, *str, **Any) -> None + super(ChallengeAuthenticationPolicy, self).__init__() + self._scopes = scopes + self._credential = credential + self._token = None # type: Optional[AccessToken] + + def _need_new_token(self): + # type: () -> bool + return not self._token or self._token.expires_on - time.time() < 300 + + def authorize_request(self, request, *scopes, **kwargs): + # type: (PipelineRequest, *str, **Any) -> None + """Acquire a token from the credential and authorize the request with it. + + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to + authorize future requests. + + :param ~azure.core.pipeline.PipelineRequest request: the request + :param str scopes: required scopes of authentication + """ + self._token = self._credential.get_token(*scopes, **kwargs) + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + def on_request(self, request): + # type: (PipelineRequest) -> None + """Called before the policy sends a request. + + The base implementation authorizes the request with a bearer token. + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + + if self._token is None or self._need_new_token(): + self._token = self._credential.get_token(*self._scopes) + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + def send(self, request): + # type: (PipelineRequest) -> PipelineResponse + """Authorizes a request with a bearer token, possibly handling an authentication challenge + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + _enforce_https(request) + + self.on_request(request) + + response = self.next.send(request) + + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + challenge = response.http_response.headers.get("WWW-Authenticate") + if challenge and self.on_challenge(request, response, challenge): + response = self.next.send(request) + + return response + + def on_challenge(self, request, response, challenge): + # type: (PipelineRequest, PipelineResponse, str) -> bool + """Authorize request according to an authentication challenge + + This method is called when the resource provider responds 401 with a WWW-Authenticate header. + + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response + :param str challenge: response's WWW-Authenticate header, unparsed. It may contain multiple challenges. + :returns: a bool indicating whether the policy should send the request + """ + # pylint:disable=unused-argument,no-self-use + return False + + class AzureKeyCredentialPolicy(SansIOHTTPPolicy): """Adds a key header for the provided credential. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index b300d15e5e78..719af6567935 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -4,10 +4,17 @@ # license information. # ------------------------------------------------------------------------- import asyncio +import time +from typing import TYPE_CHECKING -from azure.core.pipeline import PipelineRequest -from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase +from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy +from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase, _enforce_https + +if TYPE_CHECKING: + from typing import Any, Optional + from azure.core.credentials import AccessToken + from azure.core.credentials_async import AsyncTokenCredential + from azure.core.pipeline import PipelineRequest, PipelineResponse class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy): @@ -23,7 +30,7 @@ def __init__(self, credential, *scopes, **kwargs): super().__init__(credential, *scopes, **kwargs) self._lock = asyncio.Lock() - async def on_request(self, request: PipelineRequest): # pylint:disable=invalid-overridden-method + async def on_request(self, request: "PipelineRequest"): # pylint:disable=invalid-overridden-method """Adds a bearer token Authorization header to request and sends request to next policy. :param request: The pipeline request object to be modified. @@ -36,3 +43,87 @@ async def on_request(self, request: PipelineRequest): # pylint:disable=invalid- if self._need_new_token: self._token = await self._credential.get_token(*self._scopes) # type: ignore self._update_headers(request.http_request.headers, self._token.token) + + +class AsyncChallengeAuthenticationPolicy(AsyncHTTPPolicy): + """Base class for policies that authorize requests with bearer tokens and expect authentication challenges + + :param ~azure.core.credentials.AsyncTokenCredential credential: an object which can asynchronously provide access + tokens, such as a credential from :mod:`azure.identity.aio` + :param str scopes: required authentication scopes + """ + + def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None: + # pylint:disable=unused-argument + super().__init__() + self._credential = credential + self._lock = asyncio.Lock() + self._scopes = scopes + self._token = None # type: Optional[AccessToken] + + def _need_new_token(self) -> bool: + return not self._token or self._token.expires_on - time.time() < 300 + + async def on_request(self, request: "PipelineRequest") -> None: + """Called before the policy sends a request. + + The base implementation authorizes the request with a bearer token. + + :param ~azure.core.pipeline.PipelineRequest request: the request + """ + + if self._token is None or self._need_new_token(): + async with self._lock: + # double check because another coroutine may have acquired a token while we waited to acquire the lock + if not self._token or self._need_new_token(): + self._token = await self._credential.get_token(*self._scopes) + + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None: + """Acquire a token from the credential and authorize the request with it. + + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to + authorize future requests. + + :param ~azure.core.pipeline.PipelineRequest request: the request + :param str scopes: required scopes of authentication + """ + + async with self._lock: + self._token = await self._credential.get_token(*scopes, **kwargs) + request.http_request.headers["Authorization"] = "Bearer " + self._token.token + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + """Authorizes a request with a bearer token, possibly handling an authentication challenge + + :param ~azure.core.pipeline.PipelineRequest request: The request + """ + _enforce_https(request) + + await self.on_request(request) + + response = await self.next.send(request) + + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + challenge = response.http_response.headers.get("WWW-Authenticate") + if challenge: + request_authorized = await self.on_challenge(request, response, challenge) + if request_authorized: + response = await self.next.send(request) + + return response + + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse", challenge: str) -> bool: + """Authorize request according to an authentication challenge + + This method is called when the resource provider responds 401 with a WWW-Authenticate header. + + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response + :param str challenge: response's WWW-Authenticate header, unparsed. It may contain multiple challenges. + :returns: a bool indicating whether the policy should send the request + """ + # pylint:disable=unused-argument,no-self-use + return False From 420909f554eb6dbeb8f7207db5f92ad48256ae58 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 8 Mar 2021 13:52:13 -0800 Subject: [PATCH 2/5] tests (azure-core) --- .../test_challenge_authentication_async.py | 242 ++++++++++++++++++ .../tests/test_challenge_authentication.py | 205 +++++++++++++++ 2 files changed, 447 insertions(+) create mode 100644 sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py create mode 100644 sdk/core/azure-core/tests/test_challenge_authentication.py diff --git a/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py new file mode 100644 index 000000000000..10ccdd9919e9 --- /dev/null +++ b/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py @@ -0,0 +1,242 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import asyncio +import base64 +import itertools +import time +from unittest.mock import Mock + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ServiceRequestError +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies import AsyncChallengeAuthenticationPolicy, SansIOHTTPPolicy +from azure.core.pipeline.transport import HttpRequest + +import pytest + +pytestmark = pytest.mark.asyncio + + +class MockPolicy(AsyncChallengeAuthenticationPolicy): + def __init__(self, *args, **kwargs): + super(MockPolicy, self).__init__(*args, **kwargs) + self.on_challenge_called = False + + async def on_challenge(self, request, response, challenge): + self.on_challenge_called = True + return False + + +async def test_adds_header(): + """The bearer token policy should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessToken("expected_token", 2524608000) + + async def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) + return Mock() + + get_token_calls = 0 + + async def get_token(_): + nonlocal get_token_calls + get_token_calls += 1 + return expected_token + + fake_credential = Mock(get_token=get_token) + policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + pipeline = AsyncPipeline(transport=Mock(), policies=policies) + + await pipeline.run(HttpRequest("GET", "https://localhost"), context=None) + assert get_token_calls == 1 + + await pipeline.run(HttpRequest("GET", "https://localhost"), context=None) + # Didn't need a new token + assert get_token_calls == 1 + + +async def test_default_context(): + """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" + + async def send(_): + return Mock() + + async def get_token(_): + return AccessToken("", 0) + + expected_scope = "scope" + credential = Mock(get_token=Mock(wraps=get_token)) + policy = MockPolicy(credential, expected_scope) + pipeline = AsyncPipeline(transport=Mock(send=send), policies=[policy]) + + await pipeline.run(HttpRequest("GET", "https://localhost")) + + credential.get_token.assert_called_once_with(expected_scope) + + +async def test_send(): + """The bearer token policy should invoke the next policy's send method and return the result""" + expected_request = HttpRequest("GET", "https://localhost") + expected_response = Mock() + + async def verify_request(request): + assert request.http_request is expected_request + return expected_response + + fake_credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("", 0))) + policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_request)] + response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request) + + assert response is expected_response + + +async def test_token_caching(): + good_for_one_hour = AccessToken("token", time.time() + 3600) + expected_token = good_for_one_hour + get_token_calls = 0 + + async def get_token(_): + nonlocal get_token_calls + get_token_calls += 1 + return expected_token + + credential = Mock(get_token=get_token) + + async def send(_): + return Mock() + + transport = Mock(send=send) + + pipeline = AsyncPipeline(transport=transport, policies=[MockPolicy(credential, "scope")]) + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 1 # policy has no token at first request -> it should call get_token + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 1 # token is good for an hour -> policy should return it from cache + + expired_token = AccessToken("token", time.time()) + get_token_calls = 0 + expected_token = expired_token + pipeline = AsyncPipeline(transport=transport, policies=[MockPolicy(credential, "scope")]) + + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 1 + await pipeline.run(HttpRequest("GET", "https://localhost")) + assert get_token_calls == 2 # token expired -> policy should call get_token + + +async def test_optionally_enforces_https(): + """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" + + async def assert_option_popped(request, **kwargs): + assert "enforce_https" not in kwargs, "MockPolicy didn't pop the 'enforce_https' option" + return Mock() + + credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42))) + pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[MockPolicy(credential, "scope")]) + + # by default and when enforce_https=True, the policy should raise when given an insecure request + with pytest.raises(ServiceRequestError): + await pipeline.run(HttpRequest("GET", "http://not.secure")) + with pytest.raises(ServiceRequestError): + await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + + # when enforce_https=False, an insecure request should pass + await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + + # https requests should always pass + await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) + await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) + await pipeline.run(HttpRequest("GET", "https://secure")) + + +async def test_preserves_enforce_https_opt_out(): + """The policy should use request context to preserve an opt out from https enforcement""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert "enforce_https" in request.context, "'enforce_https' is not in the request's context" + + async def send(_): + return Mock() + + transport = Mock(send=send) + + get_token = get_completed_future(AccessToken("***", 42)) + credential = Mock(get_token=lambda *_, **__: get_token) + policies = [MockPolicy(credential, "scope"), ContextValidator()] + pipeline = AsyncPipeline(transport=transport, policies=policies) + + await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + + +async def test_context_unmodified_by_default(): + """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert not any(request.context), "the policy shouldn't add to the request's context" + + async def send(_): + return Mock() + + transport = Mock(send=send) + + get_token = get_completed_future(AccessToken("***", 42)) + credential = Mock(get_token=lambda *_, **__: get_token) + policies = [MockPolicy(credential, "scope"), ContextValidator()] + pipeline = AsyncPipeline(transport=transport, policies=policies) + + await pipeline.run(HttpRequest("GET", "https://secure")) + + +async def test_cannot_complete_challenge(): + """ChallengeAuthenticationPolicy should return the 401 response when it can't complete a challenge""" + + expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) + + async def send(_): + return expected_response + + transport = Mock(send=Mock(wraps=send)) + + expected_scope = "scope" + get_token = Mock(return_value=get_completed_future(AccessToken("***", 42))) + credential = Mock(get_token=get_token) + policy = MockPolicy(credential, expected_scope) + + pipeline = AsyncPipeline(transport=transport, policies=[policy]) + response = await pipeline.run(HttpRequest("GET", "https://localhost")) + + assert policy.on_challenge_called + assert response.http_response is expected_response + assert transport.send.call_count == 1 + credential.get_token.assert_called_once_with(expected_scope) + + +def get_completed_future(result=None): + fut = asyncio.Future() + fut.set_result(result) + return fut diff --git a/sdk/core/azure-core/tests/test_challenge_authentication.py b/sdk/core/azure-core/tests/test_challenge_authentication.py new file mode 100644 index 000000000000..ecc2734f61ee --- /dev/null +++ b/sdk/core/azure-core/tests/test_challenge_authentication.py @@ -0,0 +1,205 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- +import base64 +import itertools +import time + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ServiceRequestError +from azure.core.pipeline import Pipeline +from azure.core.pipeline.policies import ChallengeAuthenticationPolicy, SansIOHTTPPolicy +from azure.core.pipeline.transport import HttpRequest + +import pytest + +try: + from unittest.mock import Mock +except ImportError: + # python < 3.3 + from mock import Mock + + +class MockPolicy(ChallengeAuthenticationPolicy): + def __init__(self, *args, **kwargs): + super(MockPolicy, self).__init__(*args, **kwargs) + self.on_challenge_called = False + + def on_challenge(self, request, response, challenge): + self.on_challenge_called = True + return False + + +def test_adds_header(): + """The policy should add a header containing a token from its credential""" + # 2524608000 == 01/01/2050 @ 12:00am (UTC) + expected_token = AccessToken("expected_token", 2524608000) + + def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token) + return Mock() + + fake_credential = Mock(get_token=Mock(return_value=expected_token)) + policy = MockPolicy(fake_credential, "scope") + policies = [policy, Mock(send=verify_authorization_header)] + + pipeline = Pipeline(transport=Mock(), policies=policies) + pipeline.run(HttpRequest("GET", "https://spam.eggs")) + + assert fake_credential.get_token.call_count == 1 + + pipeline.run(HttpRequest("GET", "https://spam.eggs")) + + # Didn't need a new token + assert fake_credential.get_token.call_count == 1 + + +def test_default_context(): + """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" + expected_scope = "scope" + token = AccessToken("", 0) + credential = Mock(get_token=Mock(return_value=token)) + policy = MockPolicy(credential, expected_scope) + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + pipeline.run(HttpRequest("GET", "https://localhost")) + + credential.get_token.assert_called_once_with(expected_scope) + + +def test_send(): + """The policy should invoke the next policy's send method and return the result""" + expected_request = HttpRequest("GET", "https://spam.eggs") + expected_response = Mock() + + def verify_request(request): + assert request.http_request is expected_request + return expected_response + + fake_credential = Mock(get_token=lambda _: AccessToken("", 0)) + policy = MockPolicy(fake_credential, "scope") + policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_request)] + response = Pipeline(transport=Mock(), policies=policies).run(expected_request) + + assert response is expected_response + + +def test_token_caching(): + good_for_one_hour = AccessToken("token", time.time() + 3600) + credential = Mock(get_token=Mock(return_value=good_for_one_hour)) + policy = MockPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + pipeline.run(HttpRequest("GET", "https://spam.eggs")) + assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token + + pipeline.run(HttpRequest("GET", "https://spam.eggs")) + assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache + + expired_token = AccessToken("token", time.time()) + credential.get_token.reset_mock() + credential.get_token.return_value = expired_token + pipeline = Pipeline(transport=Mock(), policies=[MockPolicy(credential, "scope")]) + + pipeline.run(HttpRequest("GET", "https://spam.eggs")) + assert credential.get_token.call_count == 1 + + pipeline.run(HttpRequest("GET", "https://spam.eggs")) + assert credential.get_token.call_count == 2 # token expired -> policy should call get_token + + +def test_optionally_enforces_https(): + """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" + + def assert_option_popped(request, **kwargs): + assert "enforce_https" not in kwargs, "ChallengeAuthenticationPolicy didn't pop the 'enforce_https' option" + return Mock() + + credential = Mock(get_token=lambda *_, **__: AccessToken("***", 42)) + policy = MockPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(send=assert_option_popped), policies=[policy]) + + # by default and when enforce_https=True, the policy should raise when given an insecure request + with pytest.raises(ServiceRequestError): + pipeline.run(HttpRequest("GET", "http://not.secure")) + with pytest.raises(ServiceRequestError): + pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + + # when enforce_https=False, an insecure request should pass + pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + + # https requests should always pass + pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) + pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) + pipeline.run(HttpRequest("GET", "https://secure")) + + +def test_preserves_enforce_https_opt_out(): + """The policy should use request context to preserve an opt out from https enforcement""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert "enforce_https" in request.context, "'enforce_https' is not in the request's context" + return Mock() + + credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) + policy = MockPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + + +def test_context_unmodified_by_default(): + """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" + + class ContextValidator(SansIOHTTPPolicy): + def on_request(self, request): + assert not any(request.context), "the policy shouldn't add to the request's context" + + credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) + policy = MockPolicy(credential, "scope") + policies = [policy, ContextValidator()] + pipeline = Pipeline(transport=Mock(), policies=policies) + + pipeline.run(HttpRequest("GET", "https://secure")) + + +def test_cannot_complete_challenge(): + """ChallengeAuthenticationPolicy should return the 401 response when it can't complete a challenge""" + + expected_scope = "scope" + expected_token = AccessToken("***", int(time.time()) + 3600) + credential = Mock(get_token=Mock(return_value=expected_token)) + expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) + transport = Mock(send=Mock(return_value=expected_response)) + policy = MockPolicy(credential, expected_scope) + + pipeline = Pipeline(transport=transport, policies=[policy]) + response = pipeline.run(HttpRequest("GET", "https://localhost")) + + assert policy.on_challenge_called + assert response.http_response is expected_response + assert transport.send.call_count == 1 + credential.get_token.assert_called_once_with(expected_scope) From 2f524802e8cac8cc36702a9e8ba836eb1e76dbd5 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 2 Apr 2021 15:32:00 -0700 Subject: [PATCH 3/5] package metadata (azure-core) --- sdk/core/azure-core/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 4f5feaa9a18a..3906971752ca 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -5,6 +5,8 @@ ### New Features - Added `azure.core.credentials.AzureNamedKeyCredential` credential #17548. +- Added `azure.core.pipeline.policies.ChallengeAuthenticationPolicy` and + `.AsyncChallengeAuthenticationPolicy` ## 1.13.0 (2021-04-02) From 8938731f1d00bf22bcd7f58ee0884ab1ed7f9ad2 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 7 Apr 2021 14:58:21 -0700 Subject: [PATCH 4/5] remove 'challenge' parameter from on_challenge (core) --- .../azure/core/pipeline/policies/_authentication.py | 8 +++----- .../azure/core/pipeline/policies/_authentication_async.py | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 9915050c67b5..d997569867f1 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -160,21 +160,19 @@ def send(self, request): if response.http_response.status_code == 401: self._token = None # any cached token is invalid - challenge = response.http_response.headers.get("WWW-Authenticate") - if challenge and self.on_challenge(request, response, challenge): + if "WWW-Authenticate" in response.http_response.headers and self.on_challenge(request, response): response = self.next.send(request) return response - def on_challenge(self, request, response, challenge): - # type: (PipelineRequest, PipelineResponse, str) -> bool + def on_challenge(self, request, response): + # type: (PipelineRequest, PipelineResponse) -> bool """Authorize request according to an authentication challenge This method is called when the resource provider responds 401 with a WWW-Authenticate header. :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response - :param str challenge: response's WWW-Authenticate header, unparsed. It may contain multiple challenges. :returns: a bool indicating whether the policy should send the request """ # pylint:disable=unused-argument,no-self-use diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 719af6567935..0519ebfcbee8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -107,22 +107,20 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": if response.http_response.status_code == 401: self._token = None # any cached token is invalid - challenge = response.http_response.headers.get("WWW-Authenticate") - if challenge: - request_authorized = await self.on_challenge(request, response, challenge) + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) if request_authorized: response = await self.next.send(request) return response - async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse", challenge: str) -> bool: + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Authorize request according to an authentication challenge This method is called when the resource provider responds 401 with a WWW-Authenticate header. :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response - :param str challenge: response's WWW-Authenticate header, unparsed. It may contain multiple challenges. :returns: a bool indicating whether the policy should send the request """ # pylint:disable=unused-argument,no-self-use From 15e2521cc41404a1ab80d06d76efbd40e323385b Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 7 Apr 2021 15:14:04 -0700 Subject: [PATCH 5/5] update tests (core) --- .../test_challenge_authentication_async.py | 51 ++++++--------- .../tests/test_challenge_authentication.py | 65 ++++++++----------- 2 files changed, 47 insertions(+), 69 deletions(-) diff --git a/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py index 10ccdd9919e9..827476dc868b 100644 --- a/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py @@ -24,8 +24,6 @@ # # -------------------------------------------------------------------------- import asyncio -import base64 -import itertools import time from unittest.mock import Mock @@ -40,16 +38,6 @@ pytestmark = pytest.mark.asyncio -class MockPolicy(AsyncChallengeAuthenticationPolicy): - def __init__(self, *args, **kwargs): - super(MockPolicy, self).__init__(*args, **kwargs) - self.on_challenge_called = False - - async def on_challenge(self, request, response, challenge): - self.on_challenge_called = True - return False - - async def test_adds_header(): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) @@ -67,7 +55,7 @@ async def get_token(_): return expected_token fake_credential = Mock(get_token=get_token) - policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + policies = [AsyncChallengeAuthenticationPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = AsyncPipeline(transport=Mock(), policies=policies) await pipeline.run(HttpRequest("GET", "https://localhost"), context=None) @@ -89,7 +77,7 @@ async def get_token(_): expected_scope = "scope" credential = Mock(get_token=Mock(wraps=get_token)) - policy = MockPolicy(credential, expected_scope) + policy = AsyncChallengeAuthenticationPolicy(credential, expected_scope) pipeline = AsyncPipeline(transport=Mock(send=send), policies=[policy]) await pipeline.run(HttpRequest("GET", "https://localhost")) @@ -107,7 +95,7 @@ async def verify_request(request): return expected_response fake_credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("", 0))) - policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_request)] + policies = [AsyncChallengeAuthenticationPolicy(fake_credential, "scope"), Mock(send=verify_request)] response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request) assert response is expected_response @@ -130,7 +118,7 @@ async def send(_): transport = Mock(send=send) - pipeline = AsyncPipeline(transport=transport, policies=[MockPolicy(credential, "scope")]) + pipeline = AsyncPipeline(transport=transport, policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")]) await pipeline.run(HttpRequest("GET", "https://localhost")) assert get_token_calls == 1 # policy has no token at first request -> it should call get_token await pipeline.run(HttpRequest("GET", "https://localhost")) @@ -139,7 +127,7 @@ async def send(_): expired_token = AccessToken("token", time.time()) get_token_calls = 0 expected_token = expired_token - pipeline = AsyncPipeline(transport=transport, policies=[MockPolicy(credential, "scope")]) + pipeline = AsyncPipeline(transport=transport, policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")]) await pipeline.run(HttpRequest("GET", "https://localhost")) assert get_token_calls == 1 @@ -151,25 +139,25 @@ async def test_optionally_enforces_https(): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" async def assert_option_popped(request, **kwargs): - assert "enforce_https" not in kwargs, "MockPolicy didn't pop the 'enforce_https' option" + assert "enforce_https" not in kwargs, "AsyncChallengeAuthenticationPolicy didn't pop the 'enforce_https' option" return Mock() credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42))) - pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[MockPolicy(credential, "scope")]) + pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")]) # by default and when enforce_https=True, the policy should raise when given an insecure request with pytest.raises(ServiceRequestError): - await pipeline.run(HttpRequest("GET", "http://not.secure")) + await pipeline.run(HttpRequest("GET", "http://localhost")) with pytest.raises(ServiceRequestError): - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=True) # when enforce_https=False, an insecure request should pass - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) # https requests should always pass - await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) - await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) - await pipeline.run(HttpRequest("GET", "https://secure")) + await pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=False) + await pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=True) + await pipeline.run(HttpRequest("GET", "https://localhost")) async def test_preserves_enforce_https_opt_out(): @@ -186,10 +174,10 @@ async def send(_): get_token = get_completed_future(AccessToken("***", 42)) credential = Mock(get_token=lambda *_, **__: get_token) - policies = [MockPolicy(credential, "scope"), ContextValidator()] + policies = [AsyncChallengeAuthenticationPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=transport, policies=policies) - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) async def test_context_unmodified_by_default(): @@ -206,10 +194,10 @@ async def send(_): get_token = get_completed_future(AccessToken("***", 42)) credential = Mock(get_token=lambda *_, **__: get_token) - policies = [MockPolicy(credential, "scope"), ContextValidator()] + policies = [AsyncChallengeAuthenticationPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=transport, policies=policies) - await pipeline.run(HttpRequest("GET", "https://secure")) + await pipeline.run(HttpRequest("GET", "https://localhost")) async def test_cannot_complete_challenge(): @@ -225,12 +213,13 @@ async def send(_): expected_scope = "scope" get_token = Mock(return_value=get_completed_future(AccessToken("***", 42))) credential = Mock(get_token=get_token) - policy = MockPolicy(credential, expected_scope) + policy = AsyncChallengeAuthenticationPolicy(credential, expected_scope) + policy.on_challenge = Mock(wraps=policy.on_challenge) pipeline = AsyncPipeline(transport=transport, policies=[policy]) response = await pipeline.run(HttpRequest("GET", "https://localhost")) - assert policy.on_challenge_called + assert policy.on_challenge.called assert response.http_response is expected_response assert transport.send.call_count == 1 credential.get_token.assert_called_once_with(expected_scope) diff --git a/sdk/core/azure-core/tests/test_challenge_authentication.py b/sdk/core/azure-core/tests/test_challenge_authentication.py index ecc2734f61ee..f30d439a64ea 100644 --- a/sdk/core/azure-core/tests/test_challenge_authentication.py +++ b/sdk/core/azure-core/tests/test_challenge_authentication.py @@ -23,8 +23,6 @@ # THE SOFTWARE. # # -------------------------------------------------------------------------- -import base64 -import itertools import time from azure.core.credentials import AccessToken @@ -42,16 +40,6 @@ from mock import Mock -class MockPolicy(ChallengeAuthenticationPolicy): - def __init__(self, *args, **kwargs): - super(MockPolicy, self).__init__(*args, **kwargs) - self.on_challenge_called = False - - def on_challenge(self, request, response, challenge): - self.on_challenge_called = True - return False - - def test_adds_header(): """The policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) @@ -62,15 +50,15 @@ def verify_authorization_header(request): return Mock() fake_credential = Mock(get_token=Mock(return_value=expected_token)) - policy = MockPolicy(fake_credential, "scope") + policy = ChallengeAuthenticationPolicy(fake_credential, "scope") policies = [policy, Mock(send=verify_authorization_header)] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(HttpRequest("GET", "https://localhost")) assert fake_credential.get_token.call_count == 1 - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(HttpRequest("GET", "https://localhost")) # Didn't need a new token assert fake_credential.get_token.call_count == 1 @@ -81,7 +69,7 @@ def test_default_context(): expected_scope = "scope" token = AccessToken("", 0) credential = Mock(get_token=Mock(return_value=token)) - policy = MockPolicy(credential, expected_scope) + policy = ChallengeAuthenticationPolicy(credential, expected_scope) pipeline = Pipeline(transport=Mock(), policies=[policy]) pipeline.run(HttpRequest("GET", "https://localhost")) @@ -91,7 +79,7 @@ def test_default_context(): def test_send(): """The policy should invoke the next policy's send method and return the result""" - expected_request = HttpRequest("GET", "https://spam.eggs") + expected_request = HttpRequest("GET", "https://localhost") expected_response = Mock() def verify_request(request): @@ -99,8 +87,8 @@ def verify_request(request): return expected_response fake_credential = Mock(get_token=lambda _: AccessToken("", 0)) - policy = MockPolicy(fake_credential, "scope") - policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_request)] + policy = ChallengeAuthenticationPolicy(fake_credential, "scope") + policies = [policy, Mock(send=verify_request)] response = Pipeline(transport=Mock(), policies=policies).run(expected_request) assert response is expected_response @@ -109,24 +97,24 @@ def verify_request(request): def test_token_caching(): good_for_one_hour = AccessToken("token", time.time() + 3600) credential = Mock(get_token=Mock(return_value=good_for_one_hour)) - policy = MockPolicy(credential, "scope") + policy = ChallengeAuthenticationPolicy(credential, "scope") pipeline = Pipeline(transport=Mock(), policies=[policy]) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(HttpRequest("GET", "https://localhost")) assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(HttpRequest("GET", "https://localhost")) assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache expired_token = AccessToken("token", time.time()) credential.get_token.reset_mock() credential.get_token.return_value = expired_token - pipeline = Pipeline(transport=Mock(), policies=[MockPolicy(credential, "scope")]) + pipeline = Pipeline(transport=Mock(), policies=[ChallengeAuthenticationPolicy(credential, "scope")]) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(HttpRequest("GET", "https://localhost")) assert credential.get_token.call_count == 1 - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(HttpRequest("GET", "https://localhost")) assert credential.get_token.call_count == 2 # token expired -> policy should call get_token @@ -138,22 +126,22 @@ def assert_option_popped(request, **kwargs): return Mock() credential = Mock(get_token=lambda *_, **__: AccessToken("***", 42)) - policy = MockPolicy(credential, "scope") + policy = ChallengeAuthenticationPolicy(credential, "scope") pipeline = Pipeline(transport=Mock(send=assert_option_popped), policies=[policy]) # by default and when enforce_https=True, the policy should raise when given an insecure request with pytest.raises(ServiceRequestError): - pipeline.run(HttpRequest("GET", "http://not.secure")) + pipeline.run(HttpRequest("GET", "http://localhost")) with pytest.raises(ServiceRequestError): - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=True) # when enforce_https=False, an insecure request should pass - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) # https requests should always pass - pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) - pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) - pipeline.run(HttpRequest("GET", "https://secure")) + pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=False) + pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=True) + pipeline.run(HttpRequest("GET", "https://localhost")) def test_preserves_enforce_https_opt_out(): @@ -165,10 +153,10 @@ def on_request(self, request): return Mock() credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) - policy = MockPolicy(credential, "scope") + policy = ChallengeAuthenticationPolicy(credential, "scope") pipeline = Pipeline(transport=Mock(), policies=[policy]) - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False) def test_context_unmodified_by_default(): @@ -179,11 +167,11 @@ def on_request(self, request): assert not any(request.context), "the policy shouldn't add to the request's context" credential = Mock(get_token=Mock(return_value=AccessToken("***", 42))) - policy = MockPolicy(credential, "scope") + policy = ChallengeAuthenticationPolicy(credential, "scope") policies = [policy, ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "https://secure")) + pipeline.run(HttpRequest("GET", "https://localhost")) def test_cannot_complete_challenge(): @@ -194,12 +182,13 @@ def test_cannot_complete_challenge(): credential = Mock(get_token=Mock(return_value=expected_token)) expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) transport = Mock(send=Mock(return_value=expected_response)) - policy = MockPolicy(credential, expected_scope) + policy = ChallengeAuthenticationPolicy(credential, expected_scope) + policy.on_challenge = Mock(wraps=policy.on_challenge) pipeline = Pipeline(transport=transport, policies=[policy]) response = pipeline.run(HttpRequest("GET", "https://localhost")) - assert policy.on_challenge_called + assert policy.on_challenge.called assert response.http_response is expected_response assert transport.send.call_count == 1 credential.get_token.assert_called_once_with(expected_scope)