Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions sdk/core/azure-core/azure/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
PoliciesType = List[Union[HTTPPolicy, SansIOHTTPPolicy]]


def _await_result(func, *args, **kwargs):
"""If func returns an awaitable, raise that this runner can't handle it."""
result = func(*args, **kwargs)
if hasattr(result, '__await__'):
raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func))
return result


class _SansIOHTTPPolicyRunner(HTTPPolicy, Generic[HTTPRequestType, HTTPResponseType]):
"""Sync implementation of the SansIO policy.

Expand All @@ -60,14 +68,14 @@ def send(self, request):
:return: The PipelineResponse object.
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self._policy.on_request(request)
_await_result(self._policy.on_request, request)
try:
response = self.next.send(request)
except Exception: #pylint: disable=broad-except
if not self._policy.on_exception(request):
if not _await_result(self._policy.on_exception, request):
raise
else:
self._policy.on_response(request, response)
_await_result(self._policy.on_response, request, response)
return response


Expand Down
17 changes: 13 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ async def __aexit__(self, exc_type, exc_value, traceback):
return None


async def _await_result(func, *args, **kwargs):
"""If func returns an awaitable, await it."""
result = func(*args, **kwargs)
if hasattr(result, '__await__'):
# type ignore on await: https://github.com/python/mypy/issues/7587
return await result # type: ignore
return result


class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): #pylint: disable=unsubscriptable-object
"""Async implementation of the SansIO policy.

Expand All @@ -62,22 +71,22 @@ def __init__(self, policy: SansIOHTTPPolicy) -> None:
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
self._policy = policy

async def send(self, request: PipelineRequest):
async def send(self, request: PipelineRequest) -> PipelineResponse:
"""Modifies the request and sends to the next policy in the chain.

:param request: The PipelineRequest object.
:type request: ~azure.core.pipeline.PipelineRequest
:return: The PipelineResponse object.
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self._policy.on_request(request)
await _await_result(self._policy.on_request, request)
try:
response = await self.next.send(request) # type: ignore
except Exception: #pylint: disable=broad-except
if not self._policy.on_exception(request):
if not await _await_result(self._policy.on_exception, request):
raise
else:
self._policy.on_response(request, response)
await _await_result(self._policy.on_response, request, response)
return response


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# -------------------------------------------------------------------------
import time

from . import HTTPPolicy
from . import SansIOHTTPPolicy

try:
from typing import TYPE_CHECKING # pylint:disable=unused-import
Expand All @@ -16,7 +16,7 @@
# pylint:disable=unused-import
from typing import Any, Dict, Mapping, Optional
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline import PipelineRequest


# pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -51,24 +51,21 @@ def _need_new_token(self):
return not self._token or self._token.expires_on - time.time() < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def send(self, request):
# type: (PipelineRequest) -> PipelineResponse
def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
if self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
return self.next.send(request)
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
# -------------------------------------------------------------------------
import threading

from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies.authentication import _BearerTokenCredentialPolicyBase


class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
"""Adds a bearer token Authorization header to requests.

Expand All @@ -23,16 +23,13 @@ def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
self._lock = threading.Lock()

async def send(self, request: PipelineRequest) -> PipelineResponse:
async def on_request(self, request: PipelineRequest):
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
with self._lock:
if self._need_new_token:
self._token = await self._credential.get_token(*self._scopes) # type: ignore
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
return await self.next.send(request) # type: ignore
33 changes: 25 additions & 8 deletions sdk/core/azure-core/azure/core/pipeline/policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,30 @@
import copy
import logging

from typing import (TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, # pylint: disable=unused-import
Tuple, Callable, Iterator)
from typing import (
Generic,
TypeVar,
Union,
Any,
Dict,
Optional,
) # pylint: disable=unused-import

try:
from typing import Awaitable # pylint: disable=unused-import
except ImportError:
pass

from azure.core.pipeline import ABC, PipelineRequest, PipelineResponse


HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")

_LOGGER = logging.getLogger(__name__)


class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore
class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore
"""An HTTP policy ABC.

Use with a synchronous pipeline.
Expand All @@ -48,6 +60,7 @@ class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignor
instantiated and all policies chained.
:type next: ~azure.core.pipeline.policies.HTTPPolicy or ~azure.core.pipeline.transport.HTTPTransport
"""

def __init__(self):
self.next = None

Expand All @@ -64,6 +77,7 @@ def send(self, request):
:rtype: ~azure.core.pipeline.PipelineResponse
"""


class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
"""Represents a sans I/O policy.

Expand All @@ -72,18 +86,20 @@ class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
on the specifics of any particular transport. SansIOHTTPPolicy
subclasses will function in either a Pipeline or an AsyncPipeline,
and can act either before the request is done, or after.
You can optionally make these methods coroutines (or return awaitable objects)
but they will then be tied to AsyncPipeline usage.
"""

def on_request(self, request):
# type: (PipelineRequest) -> None
# type: (PipelineRequest) -> Union[None, Awaitable[None]]
"""Is executed before sending the request from next policy.

:param request: Request to be modified before sent from next policy.
:type request: ~azure.core.pipeline.PipelineRequest
"""

def on_response(self, request, response):
# type: (PipelineRequest, PipelineResponse) -> None
# type: (PipelineRequest, PipelineResponse) -> Union[None, Awaitable[None]]
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
Expand All @@ -92,9 +108,9 @@ def on_response(self, request, response):
:type response: ~azure.core.pipeline.PipelineResponse
"""

#pylint: disable=no-self-use
def on_exception(self, _request): #pylint: disable=unused-argument
# type: (PipelineRequest) -> bool
# pylint: disable=no-self-use
def on_exception(self, _request): # pylint: disable=unused-argument
# type: (PipelineRequest) -> Union[bool, Awaitable[bool]]
"""Is executed if an exception is raised while executing the next policy.

Developer can optionally implement this method to return True
Expand Down Expand Up @@ -129,6 +145,7 @@ class RequestHistory(object):
:param Exception error: An error encountered during the request, or None if the response was received successfully.
:param dict context: The pipeline context.
"""

def __init__(self, http_request, http_response=None, error=None, context=None):
# type: (PipelineRequest, Optional[PipelineResponse], Exception, Optional[Dict[str, Any]]) -> None
self.http_request = copy.deepcopy(http_request)
Expand Down