Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,30 @@ def __init__(self, policy: SansIOHTTPPolicy) -> None:
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
self._policy = policy

async def send(self, request: PipelineRequest):
async def send(self, request: PipelineRequest) -> PipelineResponse:
"""Modifies the request and sends to the next policy in the chain.

:param request: The PipelineRequest object.
:type request: ~azure.core.pipeline.PipelineRequest
:return: The PipelineResponse object.
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self._policy.on_request(request)
# type ignore on await: https://github.com/python/mypy/issues/7587
request_result = self._policy.on_request(request)
if hasattr(request_result, '__await__'):
await request_result # type: ignore
Comment thread
lmazuel marked this conversation as resolved.
Outdated
try:
response = await self.next.send(request) # type: ignore
except Exception: #pylint: disable=broad-except
if not self._policy.on_exception(request):
excp_result = self._policy.on_exception(request)
if hasattr(excp_result, '__await__'):
excp_result = await excp_result # type: ignore
if not excp_result:
raise
else:
self._policy.on_response(request, response)
resp_result = self._policy.on_response(request, response)
if hasattr(resp_result, '__await__'):
await resp_result # type: ignore
return response


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# -------------------------------------------------------------------------
import time

from . import HTTPPolicy
from . import SansIOHTTPPolicy

try:
from typing import TYPE_CHECKING # pylint:disable=unused-import
Expand All @@ -16,7 +16,7 @@
# pylint:disable=unused-import
from typing import Any, Dict, Mapping, Optional
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline import PipelineRequest


# pylint:disable=too-few-public-methods
Expand Down Expand Up @@ -51,24 +51,21 @@ def _need_new_token(self):
return not self._token or self._token.expires_on - time.time() < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def send(self, request):
# type: (PipelineRequest) -> PipelineResponse
def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
if self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
return self.next.send(request)
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
# -------------------------------------------------------------------------
import threading

from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies.authentication import _BearerTokenCredentialPolicyBase


class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
"""Adds a bearer token Authorization header to requests.

Expand All @@ -23,16 +23,13 @@ def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
self._lock = threading.Lock()

async def send(self, request: PipelineRequest) -> PipelineResponse:
async def on_request(self, request: PipelineRequest):
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
with self._lock:
if self._need_new_token:
self._token = await self._credential.get_token(*self._scopes) # type: ignore
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
return await self.next.send(request) # type: ignore
10 changes: 6 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import logging

from typing import (TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, # pylint: disable=unused-import
Tuple, Callable, Iterator)
Tuple, Callable, Iterator, Awaitable)

from azure.core.pipeline import ABC, PipelineRequest, PipelineResponse

Expand Down Expand Up @@ -72,18 +72,20 @@ class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
on the specifics of any particular transport. SansIOHTTPPolicy
subclasses will function in either a Pipeline or an AsyncPipeline,
and can act either before the request is done, or after.
You can optionally make these methods coroutines (or return awaitable objects)
but they will then be tied to AsyncPipeline usage.
"""

def on_request(self, request):
# type: (PipelineRequest) -> None
# type: (PipelineRequest) -> Union[None, Awaitable[None]]
"""Is executed before sending the request from next policy.

:param request: Request to be modified before sent from next policy.
:type request: ~azure.core.pipeline.PipelineRequest
"""

def on_response(self, request, response):
# type: (PipelineRequest, PipelineResponse) -> None
# type: (PipelineRequest, PipelineResponse) -> Union[None, Awaitable[None]]
"""Is executed after the request comes back from the policy.

:param request: Request to be modified after returning from the policy.
Expand All @@ -94,7 +96,7 @@ def on_response(self, request, response):

#pylint: disable=no-self-use
def on_exception(self, _request): #pylint: disable=unused-argument
# type: (PipelineRequest) -> bool
# type: (PipelineRequest) -> Union[bool, Awaitable[bool]]
"""Is executed if an exception is raised while executing the next policy.

Developer can optionally implement this method to return True
Expand Down