Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import six

from . import SansIOHTTPPolicy
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError

try:
Expand All @@ -18,7 +18,7 @@
# pylint:disable=unused-import
from typing import Any, Dict, Optional
from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential, AzureSasCredential
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline import PipelineRequest, PipelineResponse


# pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -71,7 +71,7 @@ def _need_new_token(self):
return not self._token or self._token.expires_on - time.time() < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy, HTTPPolicy):
Copy link
Member

@xiangyan99 xiangyan99 May 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then the pipeline will call on_request or send or both?

Do we have a design for this scenario? @annatisch

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After #16726, the pipeline will call this policy's send method (instead of _SansIOHTTPPolicyRunner.send). Another option is for this policy not to inherit SansIOHTTPPolicy.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is an "AvecIOHTTPPolicy" (as opposed to Sans), I would remove the inheritance from SansIOHTTPPolicy and derive from the "correct" base class. I think that your implementation can call self.on_request and self.on_response to make sure they are called in a very similar fashion as before in order to handle any cases where someone has extended/derived from the policy, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have done. The extended policy follows _SansIOHTTPPolicyRunner's calling pattern such that it calls a subclass's on_* as the runner would.

"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
Expand All @@ -82,16 +82,68 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo

def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.
"""Called before the policy sends a request.

The base implementation authorizes the request with a bearer token.

:param ~azure.core.pipeline.PipelineRequest request: the request
"""
if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request, *scopes, **kwargs):
# type: (PipelineRequest, *str, **Any) -> None
"""Acquire a token from the credential and authorize the request with it.

Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.

:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

def send(self, request):
# type: (PipelineRequest) -> PipelineResponse
"""Authorize request with a bearer token and send it to the next policy

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
self._enforce_https(request)
self.on_request(request)
try:
response = self.next.send(request)
self.on_response(request, response)
except Exception: # pylint:disable=broad-except
handled = self.on_exception(request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = self.on_challenge(request, response)
if request_authorized:
response = self.next.send(request)
self.on_response(request, response)

if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)
return response

def on_challenge(self, request, response):
# type: (PipelineRequest, PipelineResponse) -> bool
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False


class AzureKeyCredentialPolicy(SansIOHTTPPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,101 @@
# license information.
# -------------------------------------------------------------------------
import asyncio
import time
from typing import TYPE_CHECKING

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase

from .._tools_async import await_result

class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
if TYPE_CHECKING:
from typing import Any, Optional
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse


class AsyncBearerTokenCredentialPolicy(SansIOHTTPPolicy, AsyncHTTPPolicy):
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None:
# pylint:disable=unused-argument
super().__init__()
self._credential = credential
self._lock = asyncio.Lock()
self._scopes = scopes
self._token = None # type: Optional[AccessToken]

async def on_request(self, request: PipelineRequest): # pylint:disable=invalid-overridden-method
async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._enforce_https(request)
if self._token is None or self._need_new_token():
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
self._token = await self._credential.get_token(*self._scopes)
request.http_request.headers["Authorization"] = "Bearer " + self._token.token

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None:
"""Acquire a token from the credential and authorize the request with it.

Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.

:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
async with self._lock:
if self._need_new_token:
self._token = await self._credential.get_token(*self._scopes) # type: ignore
self._update_headers(request.http_request.headers, self._token.token)
self._token = await self._credential.get_token(*scopes, **kwargs)
request.http_request.headers["Authorization"] = "Bearer " + self._token.token

async def send(self, request: "PipelineRequest") -> "PipelineResponse":
"""Authorize request with a bearer token and send it to the next policy

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
await await_result(self.on_request, request)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
response = await self.next.send(request)
await await_result(self.on_response, request, response)

return response

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
66 changes: 53 additions & 13 deletions sdk/core/azure-core/tests/async_tests/test_authentication_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@
from unittest.mock import Mock

from azure.core.credentials import AccessToken
from azure.core.exceptions import AzureError, ServiceRequestError
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import HttpRequest
import pytest

pytestmark = pytest.mark.asyncio


@pytest.mark.asyncio
async def test_bearer_policy_adds_header():
"""The bearer token policy should add a header containing a token from its credential"""
# 2524608000 == 01/01/2050 @ 12:00am (UTC)
expected_token = AccessToken("expected_token", 2524608000)

async def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token)
return Mock()

get_token_calls = 0

Expand All @@ -43,7 +45,6 @@ async def get_token(_):
assert get_token_calls == 1


@pytest.mark.asyncio
async def test_bearer_policy_send():
"""The bearer token policy should invoke the next policy's send method and return the result"""
expected_request = HttpRequest("GET", "https://spam.eggs")
Expand All @@ -60,7 +61,6 @@ async def verify_request(request):
assert response is expected_response


@pytest.mark.asyncio
async def test_bearer_policy_token_caching():
good_for_one_hour = AccessToken("token", time.time() + 3600)
expected_token = good_for_one_hour
Expand All @@ -74,7 +74,7 @@ async def get_token(_):
credential = Mock(get_token=get_token)
policies = [
AsyncBearerTokenCredentialPolicy(credential, "scope"),
Mock(send=Mock(return_value=get_completed_future())),
Mock(send=Mock(return_value=get_completed_future(Mock()))),
]
pipeline = AsyncPipeline(transport=Mock, policies=policies)

Expand All @@ -87,7 +87,10 @@ async def get_token(_):
expired_token = AccessToken("token", time.time())
get_token_calls = 0
expected_token = expired_token
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), Mock(send=lambda _: get_completed_future())]
policies = [
AsyncBearerTokenCredentialPolicy(credential, "scope"),
Mock(send=lambda _: get_completed_future(Mock())),
]
pipeline = AsyncPipeline(transport=Mock(), policies=policies)

await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
Expand All @@ -97,12 +100,12 @@ async def get_token(_):
assert get_token_calls == 2 # token expired -> policy should call get_token


@pytest.mark.asyncio
async def test_bearer_policy_optionally_enforces_https():
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""

async def assert_option_popped(request, **kwargs):
assert "enforce_https" not in kwargs, "AsyncBearerTokenCredentialPolicy didn't pop the 'enforce_https' option"
return Mock()

credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42)))
pipeline = AsyncPipeline(
Expand All @@ -124,38 +127,75 @@ async def assert_option_popped(request, **kwargs):
await pipeline.run(HttpRequest("GET", "https://secure"))


@pytest.mark.asyncio
async def test_preserves_enforce_https_opt_out():
async def test_bearer_policy_preserves_enforce_https_opt_out():
"""The policy should use request context to preserve an opt out from https enforcement"""

class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert "enforce_https" in request.context, "'enforce_https' is not in the request's context"
return Mock()

get_token = get_completed_future(AccessToken("***", 42))
credential = Mock(get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future()), policies=policies)
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies)

await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)


@pytest.mark.asyncio
async def test_context_unmodified_by_default():
async def test_bearer_policy_context_unmodified_by_default():
"""When no options for the policy accompany a request, the policy shouldn't add anything to the request context"""

class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert not any(request.context), "the policy shouldn't add to the request's context"
return Mock()

get_token = get_completed_future(AccessToken("***", 42))
credential = Mock(get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future()), policies=policies)
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies)

await pipeline.run(HttpRequest("GET", "https://secure"))


async def test_bearer_policy_calls_sansio_methods():
"""AsyncBearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOAsyncHTTPPolicyRunner"""

class TestPolicy(AsyncBearerTokenCredentialPolicy):
def __init__(self, *args, **kwargs):
super(TestPolicy, self).__init__(*args, **kwargs)
self.on_exception = Mock(return_value=False)
self.on_request = Mock()
self.on_response = Mock()

async def send(self, request):
self.request = request
self.response = await super(TestPolicy, self).send(request)
return self.response

credential = Mock(get_token=Mock(return_value=get_completed_future(AccessToken("***", int(time.time()) + 3600))))
policy = TestPolicy(credential, "scope")
transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))

pipeline = AsyncPipeline(transport=transport, policies=[policy])
await pipeline.run(HttpRequest("GET", "https://localhost"))

policy.on_request.assert_called_once_with(policy.request)
policy.on_response.assert_called_once_with(policy.request, policy.response)

# the policy should call on_exception when next.send() raises
class TestException(Exception):
pass

transport = Mock(send=Mock(side_effect=TestException))
policy = TestPolicy(credential, "scope")
pipeline = AsyncPipeline(transport=transport, policies=[policy])
with pytest.raises(TestException):
await pipeline.run(HttpRequest("GET", "https://localhost"))
policy.on_exception.assert_called_once_with(policy.request)


def get_completed_future(result=None):
fut = asyncio.Future()
fut.set_result(result)
Expand Down
Loading