Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import six

from . import SansIOHTTPPolicy
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError

try:
Expand All @@ -18,7 +18,7 @@
# pylint:disable=unused-import
from typing import Any, Dict, Optional
from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential, AzureSasCredential
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline import PipelineRequest, PipelineResponse


# pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -71,7 +71,7 @@ def _need_new_token(self):
return not self._token or self._token.expires_on - time.time() < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember we planned to modify

if isinstance(policy, SansIOHTTPPolicy):

to check hasattibute('on_request'), will that change conflict with this one?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's #16726. No conflict with this change, and this change doesn't require that PR.

"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
Expand All @@ -82,17 +82,94 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo

def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.
"""Called before the policy sends a request.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
The base implementation authorizes the request with a bearer token.

:param ~azure.core.pipeline.PipelineRequest request: the request
"""
self._enforce_https(request)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you think about moving _enforce_https into shared utils?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense the next time another policy needs it. Do we want the key and SAS policies to enforce HTTPS as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked this because I saw ACR also used it. :)


if self._token is None or self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)

def authorize_request(self, request, *scopes, **kwargs):
# type: (PipelineRequest, *str, **Any) -> None
"""Acquire a token from the credential and authorize the request with it.

Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.

:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)

def send(self, request):
# type: (PipelineRequest) -> PipelineResponse
"""Authorize request with a bearer token and send it to the next policy

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
self.on_request(request)
try:
response = self.next.send(request)
self.on_response(request, response)
except Exception: # pylint:disable=broad-except
handled = self.on_exception(request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = self.on_challenge(request, response)
if request_authorized:
response = self.next.send(request)
self.on_response(request, response)

return response

def on_challenge(self, request, response):
# type: (PipelineRequest, PipelineResponse) -> bool
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False

def on_response(self, request, response):
# type: (PipelineRequest, PipelineResponse) -> None
"""Executed after the request comes back from the next policy.

:param request: Request to be modified after returning from the policy.
:type request: ~azure.core.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
"""

def on_exception(self, request):
# type: (PipelineRequest) -> bool
"""Executed when an exception is raised while executing the next policy.

This method is executed inside the exception handler.

:param request: The Pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: False by default, override with True to stop the exception.
:rtype: bool
"""
# pylint: disable=no-self-use,unused-argument
return False


class AzureKeyCredentialPolicy(SansIOHTTPPolicy):
"""Adds a key header for the provided credential.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,124 @@
# license information.
# -------------------------------------------------------------------------
import asyncio
import time
from typing import TYPE_CHECKING

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase

from .._tools_async import await_result

class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
if TYPE_CHECKING:
from typing import Any, Awaitable, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse


class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy):
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None:
# pylint:disable=unused-argument
super().__init__()
self._credential = credential
self._lock = asyncio.Lock()
self._scopes = scopes
self._token = None # type: Optional[AccessToken]

async def on_request(self, request: PipelineRequest): # pylint:disable=invalid-overridden-method
async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._enforce_https(request)
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access

if self._token is None or self._need_new_token():
async with self._lock:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
self._token = await self._credential.get_token(*self._scopes)
request.http_request.headers["Authorization"] = "Bearer " + self._token.token

async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None:
"""Acquire a token from the credential and authorize the request with it.

Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
authorize future requests.

:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
async with self._lock:
if self._need_new_token:
self._token = await self._credential.get_token(*self._scopes) # type: ignore
self._update_headers(request.http_request.headers, self._token.token)
self._token = await self._credential.get_token(*scopes, **kwargs)
request.http_request.headers["Authorization"] = "Bearer " + self._token.token

async def send(self, request: "PipelineRequest") -> "PipelineResponse":
"""Authorize request with a bearer token and send it to the next policy

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
"""
await await_result(self.on_request, request)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
response = await self.next.send(request)
await await_result(self.on_response, request, response)

return response

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
# pylint:disable=unused-argument,no-self-use
return False

def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> "Union[None, Awaitable[None]]":
"""Executed after the request comes back from the next policy.

:param request: Request to be modified after returning from the policy.
:type request: ~azure.core.pipeline.PipelineRequest
:param response: Pipeline response object
:type response: ~azure.core.pipeline.PipelineResponse
"""

def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[bool]]":
"""Executed when an exception is raised while executing the next policy.

This method is executed inside the exception handler.

:param request: The Pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: False by default, override with True to stop the exception.
:rtype: bool
"""
# pylint: disable=no-self-use,unused-argument
return False

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
Loading