diff --git a/sdk/core/azure-core/azure/core/pipeline/base.py b/sdk/core/azure-core/azure/core/pipeline/base.py index 76fb13c10e09..e0f940249a3b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/base.py +++ b/sdk/core/azure-core/azure/core/pipeline/base.py @@ -37,6 +37,14 @@ PoliciesType = List[Union[HTTPPolicy, SansIOHTTPPolicy]] +def _await_result(func, *args, **kwargs): + """If func returns an awaitable, raise that this runner can't handle it.""" + result = func(*args, **kwargs) + if hasattr(result, '__await__'): + raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func)) + return result + + class _SansIOHTTPPolicyRunner(HTTPPolicy, Generic[HTTPRequestType, HTTPResponseType]): """Sync implementation of the SansIO policy. @@ -60,14 +68,14 @@ def send(self, request): :return: The PipelineResponse object. :rtype: ~azure.core.pipeline.PipelineResponse """ - self._policy.on_request(request) + _await_result(self._policy.on_request, request) try: response = self.next.send(request) except Exception: #pylint: disable=broad-except - if not self._policy.on_exception(request): + if not _await_result(self._policy.on_exception, request): raise else: - self._policy.on_response(request, response) + _await_result(self._policy.on_response, request, response) return response diff --git a/sdk/core/azure-core/azure/core/pipeline/base_async.py b/sdk/core/azure-core/azure/core/pipeline/base_async.py index 88dfeabe0681..f8eea3706ef9 100644 --- a/sdk/core/azure-core/azure/core/pipeline/base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/base_async.py @@ -49,6 +49,15 @@ async def __aexit__(self, exc_type, exc_value, traceback): return None +async def _await_result(func, *args, **kwargs): + """If func returns an awaitable, await it.""" + result = func(*args, **kwargs) + if hasattr(result, '__await__'): + # type ignore on await: https://github.com/python/mypy/issues/7587 + return await result # type: ignore + return result + + class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): #pylint: disable=unsubscriptable-object """Async implementation of the SansIO policy. @@ -62,7 +71,7 @@ def __init__(self, policy: SansIOHTTPPolicy) -> None: super(_SansIOAsyncHTTPPolicyRunner, self).__init__() self._policy = policy - async def send(self, request: PipelineRequest): + async def send(self, request: PipelineRequest) -> PipelineResponse: """Modifies the request and sends to the next policy in the chain. :param request: The PipelineRequest object. @@ -70,14 +79,14 @@ async def send(self, request: PipelineRequest): :return: The PipelineResponse object. :rtype: ~azure.core.pipeline.PipelineResponse """ - self._policy.on_request(request) + await _await_result(self._policy.on_request, request) try: response = await self.next.send(request) # type: ignore except Exception: #pylint: disable=broad-except - if not self._policy.on_exception(request): + if not await _await_result(self._policy.on_exception, request): raise else: - self._policy.on_response(request, response) + await _await_result(self._policy.on_response, request, response) return response diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py index db7eed8359b3..77b8375f57c9 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/authentication.py @@ -5,7 +5,7 @@ # ------------------------------------------------------------------------- import time -from . import HTTPPolicy +from . import SansIOHTTPPolicy try: from typing import TYPE_CHECKING # pylint:disable=unused-import @@ -16,7 +16,7 @@ # pylint:disable=unused-import from typing import Any, Dict, Mapping, Optional from azure.core.credentials import AccessToken, TokenCredential - from azure.core.pipeline import PipelineRequest, PipelineResponse + from azure.core.pipeline import PipelineRequest # pylint:disable=too-few-public-methods @@ -51,7 +51,7 @@ def _need_new_token(self): return not self._token or self._token.expires_on - time.time() < 300 -class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy): +class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy): """Adds a bearer token Authorization header to requests. :param credential: The credential. @@ -59,16 +59,13 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy): :param str scopes: Lets you specify the type of access needed. """ - def send(self, request): - # type: (PipelineRequest) -> PipelineResponse + def on_request(self, request): + # type: (PipelineRequest) -> None """Adds a bearer token Authorization header to request and sends request to next policy. :param request: The pipeline request object :type request: ~azure.core.pipeline.PipelineRequest - :return: The pipeline response object - :rtype: ~azure.core.pipeline.PipelineResponse """ if self._need_new_token: self._token = self._credential.get_token(*self._scopes) self._update_headers(request.http_request.headers, self._token.token) # type: ignore - return self.next.send(request) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py index f6311a2abe46..d7b1a3aeac1c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/authentication_async.py @@ -5,12 +5,12 @@ # ------------------------------------------------------------------------- import threading -from azure.core.pipeline import PipelineRequest, PipelineResponse -from azure.core.pipeline.policies import AsyncHTTPPolicy +from azure.core.pipeline import PipelineRequest +from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.pipeline.policies.authentication import _BearerTokenCredentialPolicyBase -class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, AsyncHTTPPolicy): +class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy): # pylint:disable=too-few-public-methods """Adds a bearer token Authorization header to requests. @@ -23,16 +23,13 @@ def __init__(self, credential, *scopes, **kwargs): super().__init__(credential, *scopes, **kwargs) self._lock = threading.Lock() - async def send(self, request: PipelineRequest) -> PipelineResponse: + async def on_request(self, request: PipelineRequest): """Adds a bearer token Authorization header to request and sends request to next policy. :param request: The pipeline request object to be modified. :type request: ~azure.core.pipeline.PipelineRequest - :return: The pipeline response object - :rtype: ~azure.core.pipeline.PipelineResponse """ with self._lock: if self._need_new_token: self._token = await self._credential.get_token(*self._scopes) # type: ignore self._update_headers(request.http_request.headers, self._token.token) # type: ignore - return await self.next.send(request) # type: ignore diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/base.py b/sdk/core/azure-core/azure/core/pipeline/policies/base.py index 3a14a02ee92a..87b51a789b3e 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/base.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/base.py @@ -28,18 +28,30 @@ import copy import logging -from typing import (TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, # pylint: disable=unused-import - Tuple, Callable, Iterator) +from typing import ( + Generic, + TypeVar, + Union, + Any, + Dict, + Optional, +) # pylint: disable=unused-import + +try: + from typing import Awaitable # pylint: disable=unused-import +except ImportError: + pass from azure.core.pipeline import ABC, PipelineRequest, PipelineResponse + HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") _LOGGER = logging.getLogger(__name__) -class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore +class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore """An HTTP policy ABC. Use with a synchronous pipeline. @@ -48,6 +60,7 @@ class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignor instantiated and all policies chained. :type next: ~azure.core.pipeline.policies.HTTPPolicy or ~azure.core.pipeline.transport.HTTPTransport """ + def __init__(self): self.next = None @@ -64,6 +77,7 @@ def send(self, request): :rtype: ~azure.core.pipeline.PipelineResponse """ + class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]): """Represents a sans I/O policy. @@ -72,10 +86,12 @@ class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]): on the specifics of any particular transport. SansIOHTTPPolicy subclasses will function in either a Pipeline or an AsyncPipeline, and can act either before the request is done, or after. + You can optionally make these methods coroutines (or return awaitable objects) + but they will then be tied to AsyncPipeline usage. """ def on_request(self, request): - # type: (PipelineRequest) -> None + # type: (PipelineRequest) -> Union[None, Awaitable[None]] """Is executed before sending the request from next policy. :param request: Request to be modified before sent from next policy. @@ -83,7 +99,7 @@ def on_request(self, request): """ def on_response(self, request, response): - # type: (PipelineRequest, PipelineResponse) -> None + # type: (PipelineRequest, PipelineResponse) -> Union[None, Awaitable[None]] """Is executed after the request comes back from the policy. :param request: Request to be modified after returning from the policy. @@ -92,9 +108,9 @@ def on_response(self, request, response): :type response: ~azure.core.pipeline.PipelineResponse """ - #pylint: disable=no-self-use - def on_exception(self, _request): #pylint: disable=unused-argument - # type: (PipelineRequest) -> bool + # pylint: disable=no-self-use + def on_exception(self, _request): # pylint: disable=unused-argument + # type: (PipelineRequest) -> Union[bool, Awaitable[bool]] """Is executed if an exception is raised while executing the next policy. Developer can optionally implement this method to return True @@ -129,6 +145,7 @@ class RequestHistory(object): :param Exception error: An error encountered during the request, or None if the response was received successfully. :param dict context: The pipeline context. """ + def __init__(self, http_request, http_response=None, error=None, context=None): # type: (PipelineRequest, Optional[PipelineResponse], Exception, Optional[Dict[str, Any]]) -> None self.http_request = copy.deepcopy(http_request)