Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from collections.abc import Iterable
from azure.core import AsyncPipelineClient
from azure.core.pipeline.policies import (
ContentDecodePolicy,
DistributedTracingPolicy,
RequestIdPolicy,
)
from .policies import AsyncARMAutoResourceProviderRegistrationPolicy, ARMHttpLoggingPolicy


Expand All @@ -37,8 +33,14 @@ class AsyncARMPipelineClient(AsyncPipelineClient):

:param str base_url: URL for the request.
:keyword AsyncPipeline pipeline: If omitted, a Pipeline object is created and returned.
:keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport.
:keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
:paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
:keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
:paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
:keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for asynchronous transport.
"""

def __init__(self, base_url, **kwargs):
Expand All @@ -47,23 +49,15 @@ def __init__(self, base_url, **kwargs):
raise ValueError(
"Current implementation requires to pass 'config' if you don't pass 'policies'"
)
kwargs["policies"] = self._default_policies(**kwargs)
per_call_policies = kwargs.get('per_call_policies', [])
if isinstance(per_call_policies, Iterable):
per_call_policies.append(AsyncARMAutoResourceProviderRegistrationPolicy())
Comment on lines +53 to +54
Copy link
Member

@jiasli jiasli Apr 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

append will fail if per_call_policies is a tuple, instead of a list:

AttributeError: 'tuple' object has no attribute 'append'

else:
per_call_policies = [per_call_policies,
AsyncARMAutoResourceProviderRegistrationPolicy()]
kwargs["per_call_policies"] = per_call_policies
config = kwargs.get('config')
if not config.http_logging_policy:
config.http_logging_policy = kwargs.get('http_logging_policy', ARMHttpLoggingPolicy(**kwargs))
Comment on lines +60 to +61
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines are unnecessary as the client will always have http_logging_policy configured:

self.http_logging_policy = kwargs.get('http_logging_policy') or ARMHttpLoggingPolicy(**kwargs)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need it because in your test case, you use the azure.core.configure which does not honor http_logging_policy.

w/o it will fail the tests.

kwargs["config"] = config
super(AsyncARMPipelineClient, self).__init__(base_url, **kwargs)

@staticmethod
def _default_policies(config, **kwargs):
return [
RequestIdPolicy(**kwargs),
AsyncARMAutoResourceProviderRegistrationPolicy(),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or ARMHttpLoggingPolicy(**kwargs),
]
43 changes: 19 additions & 24 deletions sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
from azure.core import PipelineClient
from azure.core.pipeline.policies import (
ContentDecodePolicy,
DistributedTracingPolicy,
RequestIdPolicy,
)
from .policies import ARMAutoResourceProviderRegistrationPolicy, ARMHttpLoggingPolicy


Expand All @@ -38,6 +37,10 @@ class ARMPipelineClient(PipelineClient):
:param str base_url: URL for the request.
:keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned.
:keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
:paramtype per_call_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]]
:keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
:paramtype per_retry_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]]
:keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport.
"""

Expand All @@ -47,23 +50,15 @@ def __init__(self, base_url, **kwargs):
raise ValueError(
"Current implementation requires to pass 'config' if you don't pass 'policies'"
)
kwargs["policies"] = self._default_policies(**kwargs)
per_call_policies = kwargs.get('per_call_policies', [])
if isinstance(per_call_policies, Iterable):
per_call_policies.append(ARMAutoResourceProviderRegistrationPolicy())
else:
per_call_policies = [per_call_policies,
ARMAutoResourceProviderRegistrationPolicy()]
kwargs["per_call_policies"] = per_call_policies
config = kwargs.get('config')
if not config.http_logging_policy:
config.http_logging_policy = kwargs.get('http_logging_policy', ARMHttpLoggingPolicy(**kwargs))
kwargs["config"] = config
super(ARMPipelineClient, self).__init__(base_url, **kwargs)

@staticmethod
def _default_policies(config, **kwargs):
return [
RequestIdPolicy(**kwargs),
ARMAutoResourceProviderRegistrationPolicy(),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or ARMHttpLoggingPolicy(**kwargs),
]
2 changes: 1 addition & 1 deletion sdk/core/azure-mgmt-core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
'pytyped': ['py.typed'],
},
install_requires=[
"azure-core<2.0.0,>=1.9.0",
"azure-core<2.0.0,>=1.13.0",
],
extras_require={
":python_version<'3.0'": ['azure-mgmt-nspkg'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
def test_default_http_logging_policy():
config = Configuration()
pipeline_client = AsyncARMPipelineClient(base_url="test", config=config)
http_logging_policy = pipeline_client._default_policies(config=config)[-1]
http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy
assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST

def test_pass_in_http_logging_policy():
Expand All @@ -42,5 +42,5 @@ def test_pass_in_http_logging_policy():
config.http_logging_policy = http_logging_policy

pipeline_client = AsyncARMPipelineClient(base_url="test", config=config)
http_logging_policy = pipeline_client._default_policies(config=config)[-1]
http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy
assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"})
4 changes: 2 additions & 2 deletions sdk/core/azure-mgmt-core/tests/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_register_failed_policy():
def test_default_http_logging_policy():
config = Configuration()
pipeline_client = ARMPipelineClient(base_url="test", config=config)
http_logging_policy = pipeline_client._default_policies(config=config)[-1]
http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy
assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST

def test_pass_in_http_logging_policy():
Expand All @@ -183,5 +183,5 @@ def test_pass_in_http_logging_policy():
config.http_logging_policy = http_logging_policy

pipeline_client = ARMPipelineClient(base_url="test", config=config)
http_logging_policy = pipeline_client._default_policies(config=config)[-1]
http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy
assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"})
2 changes: 1 addition & 1 deletion shared_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ six>=1.11.0
isodate>=0.6.0
avro<2.0.0,>=1.10.0
#override azure azure-keyvault~=1.0
#override azure-mgmt-core azure-core<2.0.0,>=1.9.0
#override azure-mgmt-core azure-core<2.0.0,>=1.13.0
#override azure-containerregistry azure-core>=1.4.0,<2.0.0
#override azure-core-tracing-opencensus azure-core<2.0.0,>=1.0.0
#override azure-core-tracing-opentelemetry azure-core<2.0.0,>=1.13.0
Expand Down