Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
# --------------------------------------------------------------------------

import logging
from typing import Generic, TypeVar, List, Union, Any, Dict
from typing import cast, Generic, TypeVar, List, Union, Any, Dict
from azure.core.pipeline import (
AbstractContextManager,
PipelineRequest,
PipelineResponse,
PipelineContext,
)
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
from .policies._base import _implements_sansio_policy_protocol
from ._tools import await_result as _await_result

HTTPResponseType = TypeVar("HTTPResponseType")
Expand Down Expand Up @@ -130,10 +131,13 @@ def __init__(self, transport, policies=None):
self._transport = transport

for policy in policies or []:
if isinstance(policy, SansIOHTTPPolicy):
if callable(getattr(policy, "send", None)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guess this deserves a else now that raises that it won't work. Before, we were appended the policy and praying and it would have failed at the first "pipeline.run". Now, it silently drops the object passed as policy (silent drops are dangerous)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that's reasonable. Breaks a few (hacky) tests for Key Vault (fixed by #16784), we'll see what else. Pipeline still needs to silently ignore None in policy lists because Configuration uses that as a default value for policies; I added test coverage of that.

self._impl_policies.append(cast(HTTPPolicy, policy))
elif _implements_sansio_policy_protocol(policy):
policy = cast(SansIOHTTPPolicy, policy)
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
elif policy:
self._impl_policies.append(policy)
elif policy is not None:
raise ValueError('A Pipeline policy must implement a "send" method or the SansIOHTTPPolicy protocol')
for index in range(len(self._impl_policies) - 1):
self._impl_policies[index].next = self._impl_policies[index + 1]
if self._impl_policies:
Expand Down
12 changes: 8 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
# --------------------------------------------------------------------------
import abc

from typing import Any, Union, List, Generic, TypeVar, Dict
from typing import Any, cast, Union, List, Generic, TypeVar, Dict

from azure.core.pipeline import PipelineRequest, PipelineResponse, PipelineContext
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from .policies._base import _implements_sansio_policy_protocol
from ._tools_async import await_result as _await_result

AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
Expand Down Expand Up @@ -144,10 +145,13 @@ def __init__(self, transport, policies: AsyncPoliciesType = None) -> None:
self._transport = transport

for policy in policies or []:
if isinstance(policy, SansIOHTTPPolicy):
if callable(getattr(policy, "send", None)):
self._impl_policies.append(cast(AsyncHTTPPolicy, policy))
elif _implements_sansio_policy_protocol(policy):
policy = cast(SansIOHTTPPolicy, policy)
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
elif policy:
self._impl_policies.append(policy)
elif policy is not None:
raise ValueError('A Pipeline policy must implement a "send" method or the SansIOHTTPPolicy protocol')
for index in range(len(self._impl_policies) - 1):
self._impl_policies[index].next = self._impl_policies[index + 1]
if self._impl_policies:
Expand Down
6 changes: 6 additions & 0 deletions sdk/core/azure-core/azure/core/pipeline/policies/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def send(self, request):
"""


def _implements_sansio_policy_protocol(obj):
# type: (Any) -> bool
"""Returns a bool indicating whether an object implements SansIOHTTPPolicy's methods"""
return all(callable(getattr(obj, method, None)) for method in ("on_exception", "on_request", "on_response"))


class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
"""Represents a sans I/O policy.

Expand Down
47 changes: 47 additions & 0 deletions sdk/core/azure-core/tests/async_tests/test_pipeline_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import asyncio
import sys
from unittest import mock

from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import (
Expand Down Expand Up @@ -56,6 +58,51 @@
import pytest


@pytest.mark.asyncio
async def test_ignores_none_policies():
"""Pipeline shouldn't raise when a policy list includes None"""

completed_future = asyncio.Future()
completed_future.set_result(None)

transport = mock.Mock(send=mock.Mock(return_value=completed_future))
policy = mock.Mock(wraps=SansIOHTTPPolicy())
pipeline = AsyncPipeline(transport, policies=[None, policy, None])
await pipeline.run(HttpRequest("GET", "http://localhost"))

assert policy.on_request.called
assert policy.on_response.called
assert transport.send.called


@pytest.mark.asyncio
async def test_policy_wrapping():
"""AsyncPipeline should wrap only policies that implement the SansIOHTTPPolicy protocol and not send()"""

completed_future = asyncio.Future()
completed_future.set_result(None)

# this policy implements send(), so Pipeline should not wrap it with a runner
class Policy(SansIOHTTPPolicy):
def send(self, request):
return completed_future

policy = mock.MagicMock(wraps=Policy())
pipeline = AsyncPipeline(mock.Mock(), [policy])
await pipeline.run(HttpRequest("GET", "http://localhost"))
assert policy.send.call_count == 1
assert not policy.on_exception.called
assert not policy.on_request.called
assert not policy.on_response.called

policy = mock.MagicMock(wraps=SansIOHTTPPolicy())
transport = mock.Mock(send=mock.Mock(return_value=completed_future))
pipeline = AsyncPipeline(transport, [policy])
await pipeline.run(HttpRequest("GET", "http://localhost"))
assert policy.on_request.called
assert policy.on_response.called


@pytest.mark.asyncio
async def test_sans_io_exception():
class BrokenSender(AsyncHttpTransport):
Expand Down
37 changes: 37 additions & 0 deletions sdk/core/azure-core/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,43 @@

from azure.core.exceptions import AzureError


def test_ignores_none_policies():
"""Pipeline shouldn't raise when a policy list includes None"""

transport = mock.Mock()
policy = mock.Mock(wraps=SansIOHTTPPolicy())
pipeline = Pipeline(transport, policies=[None, policy, None])
pipeline.run(HttpRequest("GET", "http://localhost"))

assert policy.on_request.called
assert policy.on_response.called
assert transport.send.called


def test_policy_wrapping():
"""Pipeline should wrap only policies that implement the SansIOHTTPPolicy protocol and not send()"""

# this policy implements send(), so Pipeline should not wrap it with a runner
class Policy(SansIOHTTPPolicy):
def send(self, request):
pass

policy = mock.MagicMock(wraps=Policy())
pipeline = Pipeline(mock.Mock(), [policy])
pipeline.run(HttpRequest("GET", "http://localhost"))
assert policy.send.call_count == 1
assert not policy.on_exception.called
assert not policy.on_request.called
assert not policy.on_response.called

policy = mock.MagicMock(wraps=SansIOHTTPPolicy())
pipeline = Pipeline(mock.Mock(), [policy])
pipeline.run(HttpRequest("GET", "http://localhost"))
assert policy.on_request.called
assert policy.on_response.called


def test_default_http_logging_policy():
config = Configuration()
pipeline_client = PipelineClient(base_url="test")
Expand Down