diff --git a/.vscode/cspell.json b/.vscode/cspell.json index b011113169e6..7912e2bb84f5 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -396,6 +396,13 @@ "Phong" ] }, + { + "filename": "sdk/core/azure-core/tests/test_serialization.py", + "words": [ + "Rlcw", + "Jwcmlud" + ] + }, { "filename": "sdk/tables/azure-data-tables/tests/**/*.py", "words": [ diff --git a/sdk/core/azure-core-experimental/azure/__init__.py b/sdk/core/azure-core-experimental/azure/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-experimental/azure/__init__.py +++ b/sdk/core/azure-core-experimental/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-experimental/azure/core/__init__.py b/sdk/core/azure-core-experimental/azure/core/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-experimental/azure/core/__init__.py +++ b/sdk/core/azure-core-experimental/azure/core/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-experimental/azure/core/experimental/__init__.py b/sdk/core/azure-core-experimental/azure/core/experimental/__init__.py index a5db9f961fc2..58bcccc100fb 100644 --- a/sdk/core/azure-core-experimental/azure/core/experimental/__init__.py +++ b/sdk/core/azure-core-experimental/azure/core/experimental/__init__.py @@ -25,4 +25,5 @@ # -------------------------------------------------------------------------- from ._version import VERSION + __version__ = VERSION diff --git a/sdk/core/azure-core-experimental/azure/core/experimental/transport/__init__.py b/sdk/core/azure-core-experimental/azure/core/experimental/transport/__init__.py index b7d3501810c4..c031bc65e95a 100644 --- a/sdk/core/azure-core-experimental/azure/core/experimental/transport/__init__.py +++ b/sdk/core/azure-core-experimental/azure/core/experimental/transport/__init__.py @@ -29,16 +29,17 @@ if sys.version_info >= (3, 7): __all__ = [ - 'PyodideTransport', + "PyodideTransport", ] def __dir__(): return __all__ def __getattr__(name): - if name == 'PyodideTransport': + if name == "PyodideTransport": try: from ._pyodide import PyodideTransport + return PyodideTransport except ImportError: raise ImportError("pyodide package is not installed") diff --git a/sdk/core/azure-core-experimental/azure/core/experimental/transport/_pyodide.py b/sdk/core/azure-core-experimental/azure/core/experimental/transport/_pyodide.py index cb1220a29d63..b9e083f0a67e 100644 --- a/sdk/core/azure-core-experimental/azure/core/experimental/transport/_pyodide.py +++ b/sdk/core/azure-core-experimental/azure/core/experimental/transport/_pyodide.py @@ -27,7 +27,7 @@ from collections.abc import AsyncIterator from io import BytesIO -import js # pylint: disable=import-error +import js # pylint: disable=import-error from pyodide import JsException # pylint: disable=import-error from pyodide.http import pyfetch # pylint: disable=import-error @@ -42,7 +42,6 @@ class PyodideTransportResponse(AsyncHttpResponseImpl): """Async response object for the `PyodideTransport`.""" - def _js_stream(self): """So we get a fresh stream every time.""" return self._internal_response.clone().js_response.body @@ -62,6 +61,7 @@ async def load_body(self) -> None: if self._content is None: self._content = await self._internal_response.clone().bytes() + class PyodideStreamDownloadGenerator(AsyncIterator): """Simple stream download generator that returns the contents of a request. @@ -106,6 +106,7 @@ async def __anext__(self) -> bytes: self._buffer_left -= self._block_size return self._stream.read(self._block_size) + class PyodideTransport(AsyncioRequestsTransport): """**This object is experimental**, meaning it may be changed in a future release or might break with a future Pyodide release. This transport was built with Pyodide diff --git a/sdk/core/azure-core-experimental/samples/pyodide_integration/browser.py b/sdk/core/azure-core-experimental/samples/pyodide_integration/browser.py index d91762fbf0d8..e6a6a30ae07b 100644 --- a/sdk/core/azure-core-experimental/samples/pyodide_integration/browser.py +++ b/sdk/core/azure-core-experimental/samples/pyodide_integration/browser.py @@ -81,12 +81,9 @@ async def test_decompress_generator(self): data = b"".join([x async for x in response.iter_bytes()]) assert data == b"hello world\n" - async def test_sentiment_analysis(self): """Test that sentiment analysis works.""" - results = await self.text_analytics_client.analyze_sentiment( - ["good great amazing"] - ) + results = await self.text_analytics_client.analyze_sentiment(["good great amazing"]) assert len(results) == 1 result = results[0] assert result.sentiment == "positive" diff --git a/sdk/core/azure-core-experimental/setup.py b/sdk/core/azure-core-experimental/setup.py index 2dfc82062d89..b3277b94c575 100644 --- a/sdk/core/azure-core-experimental/setup.py +++ b/sdk/core/azure-core-experimental/setup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import re import os.path @@ -18,50 +18,49 @@ package_folder_path = "azure/core/experimental" # Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, '_version.py'), 'r') as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', # type: ignore - fd.read(), re.MULTILINE).group(1) +with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) # type: ignore if not version: - raise RuntimeError('Cannot find version information') + raise RuntimeError("Cannot find version information") -with open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() -with open('CHANGELOG.md', encoding='utf-8') as f: +with open("CHANGELOG.md", encoding="utf-8") as f: changelog = f.read() setup( name=PACKAGE_NAME, version=version, - description='Microsoft Azure {} Library for Python'.format(PACKAGE_PPRINT_NAME), - long_description=readme + '\n\n' + changelog, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='azpysdkhelp@microsoft.com', - url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core-experimental', + description="Microsoft Azure {} Library for Python".format(PACKAGE_PPRINT_NAME), + long_description=readme + "\n\n" + changelog, + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="azpysdkhelp@microsoft.com", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core-experimental", classifiers=[ "Development Status :: 4 - Beta", - 'Programming Language :: Python', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: MIT License', + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: MIT License", ], zip_safe=False, packages=[ - 'azure.core.experimental', + "azure.core.experimental", ], include_package_data=True, package_data={ - 'pytyped': ['py.typed'], + "pytyped": ["py.typed"], }, python_requires=">=3.7", install_requires=[ - 'azure-core<2.0.0,>=1.25.0', + "azure-core<2.0.0,>=1.25.0", ], ) diff --git a/sdk/core/azure-core-experimental/tests/test_pyodide_transport.py b/sdk/core/azure-core-experimental/tests/test_pyodide_transport.py index 7b84222fa908..3855c3506a38 100644 --- a/sdk/core/azure-core-experimental/tests/test_pyodide_transport.py +++ b/sdk/core/azure-core-experimental/tests/test_pyodide_transport.py @@ -78,9 +78,7 @@ def mock_pyfetch(self, mock_pyodide_module): """Utility fixture for less typing.""" return mock_pyodide_module.http.pyfetch - def create_mock_response( - self, body: bytes, headers: dict, status: int, status_text: str - ) -> mock.Mock: + def create_mock_response(self, body: bytes, headers: dict, status: int, status_text: str) -> mock.Mock: """Create a mock response object that mimics `pyodide.http.FetchResponse`""" mock_response = mock.Mock() mock_response.body = body @@ -106,9 +104,7 @@ async def test_successful_send(self, mock_pyfetch, mock_pyodide_module, pipeline method = "POST" headers = {"key": "value"} data = b"data" - request = HttpRequest( - method=method, url=PLACEHOLDER_ENDPOINT, headers=headers, data=data - ) + request = HttpRequest(method=method, url=PLACEHOLDER_ENDPOINT, headers=headers, data=data) response_body = b"0123" response_headers = {"header": "value"} response_status = 200 @@ -167,4 +163,5 @@ def test_valid_import(self, transport): """ # Use patch so we don't clutter up the `sys.modules` namespace. import azure.core.experimental.transport as transport + assert transport.PyodideTransport diff --git a/sdk/core/azure-core-tracing-opencensus/azure/__init__.py b/sdk/core/azure-core-tracing-opencensus/azure/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opencensus/azure/__init__.py +++ b/sdk/core/azure-core-tracing-opencensus/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opencensus/azure/core/__init__.py b/sdk/core/azure-core-tracing-opencensus/azure/core/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opencensus/azure/core/__init__.py +++ b/sdk/core/azure-core-tracing-opencensus/azure/core/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/__init__.py b/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/__init__.py +++ b/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/__init__.py b/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/__init__.py +++ b/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/opencensus_span/__init__.py b/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/opencensus_span/__init__.py index 6aa9f9974582..b332ba8b7f4f 100644 --- a/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/opencensus_span/__init__.py +++ b/sdk/core/azure-core-tracing-opencensus/azure/core/tracing/ext/opencensus_span/__init__.py @@ -25,6 +25,7 @@ from typing import Dict, Optional, Union, Callable, Sequence, Any from azure.core.pipeline.transport import HttpRequest, HttpResponse + AttributeValue = Union[ str, bool, @@ -40,7 +41,7 @@ __version__ = VERSION -_config_integration.trace_integrations(['threading']) +_config_integration.trace_integrations(["threading"]) class OpenCensusSpan(HttpSpanMixin, object): @@ -61,20 +62,26 @@ def __init__(self, span=None, name="span", **kwargs): :paramtype links: list[~azure.core.tracing.Link] """ tracer = self.get_current_tracer() - value = kwargs.pop('kind', None) + value = kwargs.pop("kind", None) kind = ( - OpenCensusSpanKind.CLIENT if value == SpanKind.CLIENT else - OpenCensusSpanKind.CLIENT if value == SpanKind.PRODUCER else # No producer in opencensus - OpenCensusSpanKind.SERVER if value == SpanKind.SERVER else - OpenCensusSpanKind.CLIENT if value == SpanKind.CONSUMER else # No consumer in opencensus - OpenCensusSpanKind.UNSPECIFIED if value == SpanKind.INTERNAL else # No internal in opencensus - OpenCensusSpanKind.UNSPECIFIED if value == SpanKind.UNSPECIFIED else - None - ) # type: SpanKind + OpenCensusSpanKind.CLIENT + if value == SpanKind.CLIENT + else OpenCensusSpanKind.CLIENT + if value == SpanKind.PRODUCER + else OpenCensusSpanKind.SERVER # No producer in opencensus + if value == SpanKind.SERVER + else OpenCensusSpanKind.CLIENT + if value == SpanKind.CONSUMER + else OpenCensusSpanKind.UNSPECIFIED # No consumer in opencensus + if value == SpanKind.INTERNAL + else OpenCensusSpanKind.UNSPECIFIED # No internal in opencensus + if value == SpanKind.UNSPECIFIED + else None + ) # type: SpanKind if value and kind is None: raise ValueError("Kind {} is not supported in OpenCensus".format(value)) - links = kwargs.pop('links', None) + links = kwargs.pop("links", None) self._span_instance = span or tracer.start_span(name=name, **kwargs) if kind is not None: self._span_instance.span_kind = kind @@ -84,11 +91,8 @@ def __init__(self, span=None, name="span", **kwargs): for link in links: ctx = trace_context_http_header_format.TraceContextPropagator().from_headers(link.headers) self._span_instance.add_link( - Link( - trace_id=ctx.trace_id, - span_id=ctx.span_id, - attributes=link.attributes - )) + Link(trace_id=ctx.trace_id, span_id=ctx.span_id, attributes=link.attributes) + ) except AttributeError: # we will just send the links as is if it's not ~azure.core.tracing.Link without any validation # assuming user knows what they are doing. @@ -121,10 +125,13 @@ def kind(self): """Get the span kind of this span.""" value = self.span_instance.span_kind return ( - SpanKind.CLIENT if value == OpenCensusSpanKind.CLIENT else - SpanKind.SERVER if value == OpenCensusSpanKind.SERVER else - SpanKind.UNSPECIFIED if value == OpenCensusSpanKind.UNSPECIFIED else - None + SpanKind.CLIENT + if value == OpenCensusSpanKind.CLIENT + else SpanKind.SERVER + if value == OpenCensusSpanKind.SERVER + else SpanKind.UNSPECIFIED + if value == OpenCensusSpanKind.UNSPECIFIED + else None ) _KIND_MAPPING = { @@ -171,7 +178,7 @@ def to_header(self): :return: A key value pair dictionary """ tracer_from_context = self.get_current_tracer() - temp_headers = {} # type: Dict[str, str] + temp_headers = {} # type: Dict[str, str] if tracer_from_context is not None: ctx = tracer_from_context.span_context try: @@ -205,7 +212,7 @@ def get_trace_parent(self): :return: a traceparent string :rtype: str """ - return self.to_header()['traceparent'] + return self.to_header()["traceparent"] @classmethod def link(cls, traceparent, attributes=None): @@ -216,9 +223,7 @@ def link(cls, traceparent, attributes=None): :param traceparent: A complete traceparent :type traceparent: str """ - cls.link_from_headers({ - 'traceparent': traceparent - }, attributes) + cls.link_from_headers({"traceparent": traceparent}, attributes) @classmethod def link_from_headers(cls, headers, attributes=None): @@ -231,11 +236,7 @@ def link_from_headers(cls, headers, attributes=None): """ ctx = trace_context_http_header_format.TraceContextPropagator().from_headers(headers) current_span = cls.get_current_span() - current_span.add_link(Link( - trace_id=ctx.trace_id, - span_id=ctx.span_id, - attributes=attributes - )) + current_span.add_link(Link(trace_id=ctx.trace_id, span_id=ctx.span_id, attributes=attributes)) @classmethod def get_current_span(cls): @@ -268,8 +269,7 @@ def set_current_span(cls, span): @classmethod def change_context(cls, span): # type: (Span) -> ContextManager - """Change the context for the life of this context manager. - """ + """Change the context for the life of this context manager.""" original_span = cls.get_current_span() try: execution_context.set_current_span(span) diff --git a/sdk/core/azure-core-tracing-opencensus/setup.py b/sdk/core/azure-core-tracing-opencensus/setup.py index 66cc4aa4069a..2f8ad67562c5 100644 --- a/sdk/core/azure-core-tracing-opencensus/setup.py +++ b/sdk/core/azure-core-tracing-opencensus/setup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import re import os.path @@ -18,53 +18,52 @@ package_folder_path = "azure/core/tracing/ext/opencensus_span" # Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, '_version.py'), 'r') as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', # type: ignore - fd.read(), re.MULTILINE).group(1) +with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) # type: ignore if not version: - raise RuntimeError('Cannot find version information') + raise RuntimeError("Cannot find version information") -with open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() -with open('CHANGELOG.md', encoding='utf-8') as f: +with open("CHANGELOG.md", encoding="utf-8") as f: changelog = f.read() setup( name=PACKAGE_NAME, version=version, - description='Microsoft Azure {} Library for Python'.format(PACKAGE_PPRINT_NAME), - long_description=readme + '\n\n' + changelog, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='azpysdkhelp@microsoft.com', - url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core-tracing-opencensus', + description="Microsoft Azure {} Library for Python".format(PACKAGE_PPRINT_NAME), + long_description=readme + "\n\n" + changelog, + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="azpysdkhelp@microsoft.com", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core-tracing-opencensus", classifiers=[ "Development Status :: 4 - Beta", - 'Programming Language :: Python', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'License :: OSI Approved :: MIT License', + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "License :: OSI Approved :: MIT License", ], zip_safe=False, packages=[ - 'azure.core.tracing.ext.opencensus_span', + "azure.core.tracing.ext.opencensus_span", ], include_package_data=True, package_data={ - 'pytyped': ['py.typed'], + "pytyped": ["py.typed"], }, python_requires=">=3.6", install_requires=[ - 'opencensus>=0.6.0', - 'opencensus-ext-azure>=0.3.1', - 'opencensus-ext-threading', - 'azure-core<2.0.0,>=1.13.0', + "opencensus>=0.6.0", + "opencensus-ext-azure>=0.3.1", + "opencensus-ext-threading", + "azure-core<2.0.0,>=1.13.0", ], ) diff --git a/sdk/core/azure-core-tracing-opencensus/tests/test_threading.py b/sdk/core/azure-core-tracing-opencensus/tests/test_threading.py index fb1c6a20bfec..30041affe53a 100644 --- a/sdk/core/azure-core-tracing-opencensus/tests/test_threading.py +++ b/sdk/core/azure-core-tracing-opencensus/tests/test_threading.py @@ -14,6 +14,7 @@ def test_get_span_from_thread(): result = [] + def get_span_from_thread(output): current_span = OpenCensusSpan.get_current_span() output.append(current_span) @@ -21,13 +22,10 @@ def get_span_from_thread(output): tracer = Tracer(sampler=AlwaysOnSampler()) with tracer.span(name="TestSpan") as span: - thread = threading.Thread( - target=get_span_from_thread, - args=(result,) - ) + thread = threading.Thread(target=get_span_from_thread, args=(result,)) thread.start() thread.join() assert span is result[0] - execution_context.clear() \ No newline at end of file + execution_context.clear() diff --git a/sdk/core/azure-core-tracing-opencensus/tests/test_tracing_implementations.py b/sdk/core/azure-core-tracing-opencensus/tests/test_tracing_implementations.py index 885f00816d0e..081db68e8b0f 100644 --- a/sdk/core/azure-core-tracing-opencensus/tests/test_tracing_implementations.py +++ b/sdk/core/azure-core-tracing-opencensus/tests/test_tracing_implementations.py @@ -23,6 +23,7 @@ import pytest + class TestOpencensusWrapper(unittest.TestCase): def test_span_passed_in(self): with ContextHelper(): @@ -141,10 +142,7 @@ def test_passing_links_in_ctor(self): trace = tracer_module.Tracer(sampler=AlwaysOnSampler()) parent = trace.start_span() wrapped_class = OpenCensusSpan( - links=[Link( - headers= {"traceparent": "00-2578531519ed94423ceae67588eff2c9-231ebdc614cb9ddd-01"} - ) - ] + links=[Link(headers={"traceparent": "00-2578531519ed94423ceae67588eff2c9-231ebdc614cb9ddd-01"})] ) assert len(wrapped_class.span_instance.links) == 1 link = wrapped_class.span_instance.links[0] @@ -157,9 +155,10 @@ def test_passing_links_in_ctor_with_attr(self): trace = tracer_module.Tracer(sampler=AlwaysOnSampler()) parent = trace.start_span() wrapped_class = OpenCensusSpan( - links=[Link( - headers= {"traceparent": "00-2578531519ed94423ceae67588eff2c9-231ebdc614cb9ddd-01"}, - attributes=attributes + links=[ + Link( + headers={"traceparent": "00-2578531519ed94423ceae67588eff2c9-231ebdc614cb9ddd-01"}, + attributes=attributes, ) ] ) @@ -169,7 +168,6 @@ def test_passing_links_in_ctor_with_attr(self): assert link.trace_id == "2578531519ed94423ceae67588eff2c9" assert link.span_id == "231ebdc614cb9ddd" - def test_set_http_attributes(self): with ContextHelper(): trace = tracer_module.Tracer(sampler=AlwaysOnSampler()) diff --git a/sdk/core/azure-core-tracing-opentelemetry/azure/__init__.py b/sdk/core/azure-core-tracing-opentelemetry/azure/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/azure/__init__.py +++ b/sdk/core/azure-core-tracing-opentelemetry/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opentelemetry/azure/core/__init__.py b/sdk/core/azure-core-tracing-opentelemetry/azure/core/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/azure/core/__init__.py +++ b/sdk/core/azure-core-tracing-opentelemetry/azure/core/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/__init__.py b/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/__init__.py +++ b/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/__init__.py b/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/__init__.py +++ b/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/opentelemetry_span/__init__.py b/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/opentelemetry_span/__init__.py index f872222877ce..da1db4dd3f43 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/opentelemetry_span/__init__.py +++ b/sdk/core/azure-core-tracing-opentelemetry/azure/core/tracing/ext/opentelemetry_span/__init__.py @@ -24,6 +24,7 @@ from typing import Any, Mapping, Dict, Optional, Union, Callable, Sequence from azure.core.pipeline.transport import HttpRequest, HttpResponse + AttributeValue = Union[ str, bool, @@ -56,20 +57,26 @@ def __init__(self, span=None, name="span", **kwargs): current_tracer = self.get_current_tracer() ## kind - value = kwargs.pop('kind', None) + value = kwargs.pop("kind", None) kind = ( - OpenTelemetrySpanKind.CLIENT if value == SpanKind.CLIENT else - OpenTelemetrySpanKind.PRODUCER if value == SpanKind.PRODUCER else - OpenTelemetrySpanKind.SERVER if value == SpanKind.SERVER else - OpenTelemetrySpanKind.CONSUMER if value == SpanKind.CONSUMER else - OpenTelemetrySpanKind.INTERNAL if value == SpanKind.INTERNAL else - OpenTelemetrySpanKind.INTERNAL if value == SpanKind.UNSPECIFIED else - None - ) # type: SpanKind + OpenTelemetrySpanKind.CLIENT + if value == SpanKind.CLIENT + else OpenTelemetrySpanKind.PRODUCER + if value == SpanKind.PRODUCER + else OpenTelemetrySpanKind.SERVER + if value == SpanKind.SERVER + else OpenTelemetrySpanKind.CONSUMER + if value == SpanKind.CONSUMER + else OpenTelemetrySpanKind.INTERNAL + if value == SpanKind.INTERNAL + else OpenTelemetrySpanKind.INTERNAL + if value == SpanKind.UNSPECIFIED + else None + ) # type: SpanKind if value and kind is None: raise ValueError("Kind {} is not supported in OpenTelemetry".format(value)) - links = kwargs.pop('links', None) + links = kwargs.pop("links", None) if links: try: ot_links = [] @@ -77,11 +84,11 @@ def __init__(self, span=None, name="span", **kwargs): ctx = extract(link.headers) span_ctx = get_span_from_context(ctx).get_span_context() ot_links.append(OpenTelemetryLink(span_ctx, link.attributes)) - kwargs.setdefault('links', ot_links) + kwargs.setdefault("links", ot_links) except AttributeError: # we will just send the links as is if it's not ~azure.core.tracing.Link without any validation # assuming user knows what they are doing. - kwargs.setdefault('links', links) + kwargs.setdefault("links", links) self._span_instance = span or current_tracer.start_span(name=name, kind=kind, **kwargs) self._current_ctxt_manager = None @@ -112,32 +119,42 @@ def kind(self): """Get the span kind of this span.""" value = self.span_instance.kind return ( - SpanKind.CLIENT if value == OpenTelemetrySpanKind.CLIENT else - SpanKind.PRODUCER if value == OpenTelemetrySpanKind.PRODUCER else - SpanKind.SERVER if value == OpenTelemetrySpanKind.SERVER else - SpanKind.CONSUMER if value == OpenTelemetrySpanKind.CONSUMER else - SpanKind.INTERNAL if value == OpenTelemetrySpanKind.INTERNAL else - None + SpanKind.CLIENT + if value == OpenTelemetrySpanKind.CLIENT + else SpanKind.PRODUCER + if value == OpenTelemetrySpanKind.PRODUCER + else SpanKind.SERVER + if value == OpenTelemetrySpanKind.SERVER + else SpanKind.CONSUMER + if value == OpenTelemetrySpanKind.CONSUMER + else SpanKind.INTERNAL + if value == OpenTelemetrySpanKind.INTERNAL + else None ) - @kind.setter def kind(self, value): # type: (SpanKind) -> None """Set the span kind of this span.""" kind = ( - OpenTelemetrySpanKind.CLIENT if value == SpanKind.CLIENT else - OpenTelemetrySpanKind.PRODUCER if value == SpanKind.PRODUCER else - OpenTelemetrySpanKind.SERVER if value == SpanKind.SERVER else - OpenTelemetrySpanKind.CONSUMER if value == SpanKind.CONSUMER else - OpenTelemetrySpanKind.INTERNAL if value == SpanKind.INTERNAL else - OpenTelemetrySpanKind.INTERNAL if value == SpanKind.UNSPECIFIED else - None + OpenTelemetrySpanKind.CLIENT + if value == SpanKind.CLIENT + else OpenTelemetrySpanKind.PRODUCER + if value == SpanKind.PRODUCER + else OpenTelemetrySpanKind.SERVER + if value == SpanKind.SERVER + else OpenTelemetrySpanKind.CONSUMER + if value == SpanKind.CONSUMER + else OpenTelemetrySpanKind.INTERNAL + if value == SpanKind.INTERNAL + else OpenTelemetrySpanKind.INTERNAL + if value == SpanKind.UNSPECIFIED + else None ) if kind is None: raise ValueError("Kind {} is not supported in OpenTelemetry".format(value)) try: - self._span_instance._kind = kind # pylint: disable=protected-access + self._span_instance._kind = kind # pylint: disable=protected-access except AttributeError: warnings.warn( """Kind must be set while creating the span for OpenTelemetry. It might be possible @@ -148,13 +165,13 @@ def kind(self, value): def __enter__(self): """Start a span.""" self._current_ctxt_manager = trace.use_span(self._span_instance, end_on_exit=True) - self._current_ctxt_manager.__enter__() # pylint: disable=no-member + self._current_ctxt_manager.__enter__() # pylint: disable=no-member return self def __exit__(self, exception_type, exception_value, traceback): """Finish a span.""" if self._current_ctxt_manager: - self._current_ctxt_manager.__exit__(exception_type, exception_value, traceback) # pylint: disable=no-member + self._current_ctxt_manager.__exit__(exception_type, exception_value, traceback) # pylint: disable=no-member self._current_ctxt_manager = None def start(self): @@ -166,13 +183,13 @@ def finish(self): """Set the end time for a span.""" self.span_instance.end() - def to_header(self): # pylint: disable=no-self-use + def to_header(self): # pylint: disable=no-self-use # type: () -> Dict[str, str] """ Returns a dictionary with the header labels and values. :return: A key value pair dictionary """ - temp_headers = {} # type: Dict[str, str] + temp_headers = {} # type: Dict[str, str] inject(temp_headers) return temp_headers @@ -202,7 +219,7 @@ def get_trace_parent(self): :return: a traceparent string :rtype: str """ - return self.to_header()['traceparent'] + return self.to_header()["traceparent"] @classmethod def link(cls, traceparent, attributes=None): @@ -213,9 +230,7 @@ def link(cls, traceparent, attributes=None): :param traceparent: A complete traceparent :type traceparent: str """ - cls.link_from_headers({ - 'traceparent': traceparent - }, attributes) + cls.link_from_headers({"traceparent": traceparent}, attributes) @classmethod def link_from_headers(cls, headers, attributes=None): @@ -230,7 +245,7 @@ def link_from_headers(cls, headers, attributes=None): span_ctx = get_span_from_context(ctx).get_span_context() current_span = cls.get_current_span() try: - current_span._links.append(OpenTelemetryLink(span_ctx, attributes)) # pylint: disable=protected-access + current_span._links.append(OpenTelemetryLink(span_ctx, attributes)) # pylint: disable=protected-access except AttributeError: warnings.warn( """Link must be added while creating the span for OpenTelemetry. It might be possible @@ -257,15 +272,13 @@ def get_current_tracer(cls): @classmethod def change_context(cls, span): # type: (Span) -> ContextManager - """Change the context for the life of this context manager. - """ + """Change the context for the life of this context manager.""" return trace.use_span(span, end_on_exit=False) @classmethod def set_current_span(cls, span): # type: (Span) -> None - """Not supported by OpenTelemetry. - """ + """Not supported by OpenTelemetry.""" raise NotImplementedError( "set_current_span is not supported by OpenTelemetry plugin. Use change_context instead." ) diff --git a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventgrid.py b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventgrid.py index 8297190acd27..b31cbfe6e726 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventgrid.py +++ b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventgrid.py @@ -35,9 +35,7 @@ trace.set_tracer_provider(TracerProvider()) tracer = trace.get_tracer(__name__) -trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(exporter) -) +trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) # Example with Eventgrid SDKs import os @@ -45,14 +43,9 @@ from azure.eventgrid import EventGridPublisherClient from azure.core.credentials import AzureKeyCredential -hostname = os.environ['CLOUD_TOPIC_HOSTNAME'] -key = AzureKeyCredential(os.environ['CLOUD_ACCESS_KEY']) -cloud_event = CloudEvent( - source = 'demo', - type = 'sdk.demo', - data = {'test': 'hello'}, - extensions = {'test': 'maybe'} -) +hostname = os.environ["CLOUD_TOPIC_HOSTNAME"] +key = AzureKeyCredential(os.environ["CLOUD_ACCESS_KEY"]) +cloud_event = CloudEvent(source="demo", type="sdk.demo", data={"test": "hello"}, extensions={"test": "maybe"}) with tracer.start_as_current_span(name="MyApplication"): client = EventGridPublisherClient(hostname, key) client.send(cloud_event) diff --git a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventhubs.py b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventhubs.py index 1130f0703013..e4db96f1c3ee 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventhubs.py +++ b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_eventhubs.py @@ -34,29 +34,29 @@ trace.set_tracer_provider(TracerProvider()) tracer = trace.get_tracer(__name__) -trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(exporter) -) +trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) from azure.eventhub import EventHubProducerClient, EventData import os -FULLY_QUALIFIED_NAMESPACE = os.environ['EVENT_HUB_HOSTNAME'] -EVENTHUB_NAME = os.environ['EVENT_HUB_NAME'] +FULLY_QUALIFIED_NAMESPACE = os.environ["EVENT_HUB_HOSTNAME"] +EVENTHUB_NAME = os.environ["EVENT_HUB_NAME"] + +credential = os.environ["EVENTHUB_CONN_STR"] -credential = os.environ['EVENTHUB_CONN_STR'] def on_event(context, event): print(context.partition_id, ":", event) + with tracer.start_as_current_span(name="MyApplication"): producer_client = EventHubProducerClient.from_connection_string( conn_str=credential, fully_qualified_namespace=FULLY_QUALIFIED_NAMESPACE, eventhub_name=EVENTHUB_NAME, - logging_enable=True + logging_enable=True, ) with producer_client: event_data_batch = producer_client.create_batch() - event_data_batch.add(EventData('Single message')) + event_data_batch.add(EventData("Single message")) producer_client.send_batch(event_data_batch) diff --git a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_eh.py b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_eh.py index 8ace53558f24..7661bd819e9a 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_eh.py +++ b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_eh.py @@ -34,23 +34,23 @@ trace.set_tracer_provider(TracerProvider()) tracer = trace.get_tracer(__name__) -trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(exporter) -) +trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) from azure.eventhub import EventHubProducerClient, EventData, EventHubConsumerClient import os -FULLY_QUALIFIED_NAMESPACE = os.environ['EVENT_HUB_HOSTNAME'] -EVENTHUB_NAME = os.environ['EVENT_HUB_NAME'] +FULLY_QUALIFIED_NAMESPACE = os.environ["EVENT_HUB_HOSTNAME"] +EVENTHUB_NAME = os.environ["EVENT_HUB_NAME"] + +credential = os.environ["EVENTHUB_CONN_STR"] -credential = os.environ['EVENTHUB_CONN_STR'] def on_event(partition_context, event): # Put your code here. # If the operation is i/o intensive, multi-thread will have better performance. print("Received event from partition: {}.".format(partition_context.partition_id)) + def on_partition_initialize(partition_context): # Put your code here. print("Partition: {} has been initialized.".format(partition_context.partition_id)) @@ -58,26 +58,25 @@ def on_partition_initialize(partition_context): def on_partition_close(partition_context, reason): # Put your code here. - print("Partition: {} has been closed, reason for closing: {}.".format( - partition_context.partition_id, - reason - )) + print("Partition: {} has been closed, reason for closing: {}.".format(partition_context.partition_id, reason)) def on_error(partition_context, error): # Put your code here. partition_context can be None in the on_error callback. if partition_context: - print("An exception: {} occurred during receiving from Partition: {}.".format( - partition_context.partition_id, - error - )) + print( + "An exception: {} occurred during receiving from Partition: {}.".format( + partition_context.partition_id, error + ) + ) else: print("An exception: {} occurred during the load balance process.".format(error)) + with tracer.start_as_current_span(name="MyApplication"): consumer_client = EventHubConsumerClient.from_connection_string( conn_str=credential, - consumer_group='$Default', + consumer_group="$Default", eventhub_name=EVENTHUB_NAME, ) @@ -91,5 +90,4 @@ def on_error(partition_context, error): starting_position="-1", # "-1" is from the beginning of the partition. ) except KeyboardInterrupt: - print('Stopped receiving.') - + print("Stopped receiving.") diff --git a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_sb.py b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_sb.py index 5a7327235f93..79b0e988506d 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_sb.py +++ b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_receive_sb.py @@ -30,21 +30,19 @@ # Simple console exporter exporter = ConsoleSpanExporter() -span_processor = SimpleSpanProcessor( - exporter -) +span_processor = SimpleSpanProcessor(exporter) trace.get_tracer_provider().add_span_processor(span_processor) # Example with Servicebus SDKs from azure.servicebus import ServiceBusClient, ServiceBusMessage -connstr = os.environ['SERVICE_BUS_CONN_STR'] -queue_name = os.environ['SERVICE_BUS_QUEUE_NAME'] +connstr = os.environ["SERVICE_BUS_CONN_STR"] +queue_name = os.environ["SERVICE_BUS_QUEUE_NAME"] with tracer.start_as_current_span(name="MyApplication2"): with ServiceBusClient.from_connection_string(connstr) as client: with client.get_queue_sender(queue_name) as sender: - #Sending a single message + # Sending a single message single_message = ServiceBusMessage("Single message") sender.send_messages(single_message) # continually receives new messages until it doesn't receive any new messages for 5 (max_wait_time) seconds. diff --git a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_servicebus.py b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_servicebus.py index 558527a06aea..84a1bb50d6a5 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_servicebus.py +++ b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_servicebus.py @@ -34,17 +34,15 @@ trace.set_tracer_provider(TracerProvider()) tracer = trace.get_tracer(__name__) -trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(exporter) -) +trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) # Example with Servicebus SDKs from azure.servicebus import ServiceBusClient, ServiceBusMessage import os -connstr = os.environ['SERVICE_BUS_CONN_STR'] -queue_name = os.environ['SERVICE_BUS_QUEUE_NAME'] +connstr = os.environ["SERVICE_BUS_CONN_STR"] +queue_name = os.environ["SERVICE_BUS_QUEUE_NAME"] with tracer.start_as_current_span(name="MyApplication"): diff --git a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_storage.py b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_storage.py index 7cd50751b14a..39655accd674 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/samples/sample_storage.py +++ b/sdk/core/azure-core-tracing-opentelemetry/samples/sample_storage.py @@ -35,16 +35,14 @@ trace.set_tracer_provider(TracerProvider()) tracer = trace.get_tracer(__name__) -trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(exporter) -) +trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) # Example with Storage SDKs import os from azure.storage.blob import BlobServiceClient -connection_string = os.environ['AZURE_STORAGE_CONNECTION_STRING'] -container_name = os.environ['AZURE_STORAGE_BLOB_CONTAINER_NAME'] +connection_string = os.environ["AZURE_STORAGE_CONNECTION_STRING"] +container_name = os.environ["AZURE_STORAGE_BLOB_CONTAINER_NAME"] with tracer.start_as_current_span(name="MyApplication"): client = BlobServiceClient.from_connection_string(connection_string) diff --git a/sdk/core/azure-core-tracing-opentelemetry/setup.py b/sdk/core/azure-core-tracing-opentelemetry/setup.py index 003aa6d5badf..8293cebe9faa 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/setup.py +++ b/sdk/core/azure-core-tracing-opentelemetry/setup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import re import os.path @@ -18,51 +18,50 @@ package_folder_path = "azure/core/tracing/ext/opentelemetry_span" # Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, '_version.py'), 'r') as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', # type: ignore - fd.read(), re.MULTILINE).group(1) +with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) # type: ignore if not version: - raise RuntimeError('Cannot find version information') + raise RuntimeError("Cannot find version information") -with open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() -with open('CHANGELOG.md', encoding='utf-8') as f: +with open("CHANGELOG.md", encoding="utf-8") as f: changelog = f.read() setup( name=PACKAGE_NAME, version=version, - description='Microsoft Azure {} Library for Python'.format(PACKAGE_PPRINT_NAME), - long_description=readme + '\n\n' + changelog, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='azpysdkhelp@microsoft.com', - url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core-tracing-opentelemetry', + description="Microsoft Azure {} Library for Python".format(PACKAGE_PPRINT_NAME), + long_description=readme + "\n\n" + changelog, + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="azpysdkhelp@microsoft.com", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core-tracing-opentelemetry", classifiers=[ "Development Status :: 4 - Beta", - 'Programming Language :: Python', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'License :: OSI Approved :: MIT License', + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "License :: OSI Approved :: MIT License", ], zip_safe=False, packages=[ - 'azure.core.tracing.ext.opentelemetry_span', + "azure.core.tracing.ext.opentelemetry_span", ], include_package_data=True, package_data={ - 'pytyped': ['py.typed'], + "pytyped": ["py.typed"], }, python_requires=">=3.6", install_requires=[ - 'opentelemetry-api<2.0.0,>=1.0.0', - 'azure-core<2.0.0,>=1.13.0', + "opentelemetry-api<2.0.0,>=1.0.0", + "azure-core<2.0.0,>=1.13.0", ], ) diff --git a/sdk/core/azure-core-tracing-opentelemetry/tests/test_threading.py b/sdk/core/azure-core-tracing-opentelemetry/tests/test_threading.py index ffdea46be91f..19b6c561a872 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/tests/test_threading.py +++ b/sdk/core/azure-core-tracing-opentelemetry/tests/test_threading.py @@ -10,16 +10,14 @@ def test_get_span_from_thread(tracer): result = [] + def get_span_from_thread(output): current_span = OpenTelemetrySpan.get_current_span() output.append(current_span) with tracer.start_as_current_span(name="TestSpan") as span: - thread = threading.Thread( - target=OpenTelemetrySpan.with_current_context(get_span_from_thread), - args=(result,) - ) + thread = threading.Thread(target=OpenTelemetrySpan.with_current_context(get_span_from_thread), args=(result,)) thread.start() thread.join() diff --git a/sdk/core/azure-core-tracing-opentelemetry/tests/test_tracing_implementations.py b/sdk/core/azure-core-tracing-opentelemetry/tests/test_tracing_implementations.py index 769207087099..df6cf33bb7d7 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/tests/test_tracing_implementations.py +++ b/sdk/core/azure-core-tracing-opentelemetry/tests/test_tracing_implementations.py @@ -41,7 +41,6 @@ def test_no_span_passed_in_with_no_environ(self, tracer): assert parent is trace.get_current_span() - def test_span(self, tracer): with tracer.start_as_current_span("Root") as parent: with OpenTelemetrySpan() as wrapped_span: @@ -188,4 +187,4 @@ def test_span_kind(self, tracer): assert wrapped_class.kind == SpanKind.INTERNAL with pytest.raises(ValueError): - wrapped_class.kind = "somethingstuid" \ No newline at end of file + wrapped_class.kind = "somethingstuid" diff --git a/sdk/core/azure-core/azure/core/_enum_meta.py b/sdk/core/azure-core/azure/core/_enum_meta.py index 3015ce3faf24..1d821573115a 100644 --- a/sdk/core/azure-core/azure/core/_enum_meta.py +++ b/sdk/core/azure-core/azure/core/_enum_meta.py @@ -45,9 +45,7 @@ class MyCustomEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta): def __getitem__(cls, name: str) -> Any: # disabling pylint bc of pylint bug https://github.com/PyCQA/astroid/issues/713 - return super( # pylint: disable=no-value-for-parameter - CaseInsensitiveEnumMeta, cls - ).__getitem__(name.upper()) + return super(CaseInsensitiveEnumMeta, cls).__getitem__(name.upper()) # pylint: disable=no-value-for-parameter def __getattr__(cls, name: str) -> Enum: """Return the enum member matching `name` diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index 6460eb0a04f4..c3773924f873 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -152,8 +152,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use index_of_retry = index if index_of_retry == -1: raise ValueError( - "Failed to add per_retry_policies; " - "no RetryPolicy found in the supplied list of policies. " + "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. " ) policies_1 = policies[: index_of_retry + 1] policies_2 = policies[index_of_retry + 1 :] @@ -185,9 +184,7 @@ def send_request(self, request: "HTTPRequestType", **kwargs) -> "HTTPResponseTyp """ stream = kwargs.pop("stream", False) # want to add default value return_pipeline_response = kwargs.pop("_return_pipeline_response", False) - pipeline_response = self._pipeline.run( - request, stream=stream, **kwargs - ) # pylint: disable=protected-access + pipeline_response = self._pipeline.run(request, stream=stream, **kwargs) # pylint: disable=protected-access if return_pipeline_response: return pipeline_response return pipeline_response.http_response diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 8207e4d4216d..7008f8fa119e 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -58,9 +58,7 @@ async def close(self): HTTPRequestType = TypeVar("HTTPRequestType") -AsyncHTTPResponseType = TypeVar( - "AsyncHTTPResponseType", bound="_AsyncContextManagerCloseable" -) +AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="_AsyncContextManagerCloseable") _LOGGER = logging.getLogger(__name__) @@ -126,9 +124,7 @@ async def close(self) -> None: await self._response.close() -class AsyncPipelineClient( - PipelineClientBase, AsyncContextManager["AsyncPipelineClient"] -): +class AsyncPipelineClient(PipelineClientBase, AsyncContextManager["AsyncPipelineClient"]): """Service client core methods. Builds an AsyncPipeline client. @@ -233,8 +229,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use index_of_retry = index if index_of_retry == -1: raise ValueError( - "Failed to add per_retry_policies; " - "no RetryPolicy found in the supplied list of policies. " + "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. " ) policies_1 = policies[: index_of_retry + 1] policies_2 = policies[index_of_retry + 1 :] @@ -251,9 +246,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use async def _make_pipeline_call(self, request, **kwargs): return_pipeline_response = kwargs.pop("_return_pipeline_response", False) - pipeline_response = await self._pipeline.run( - request, **kwargs # pylint: disable=protected-access - ) + pipeline_response = await self._pipeline.run(request, **kwargs) # pylint: disable=protected-access if return_pipeline_response: return pipeline_response return pipeline_response.http_response diff --git a/sdk/core/azure-core/azure/core/async_paging.py b/sdk/core/azure-core/azure/core/async_paging.py index 62f7eda3ae6f..3b3ef9d54131 100644 --- a/sdk/core/azure-core/azure/core/async_paging.py +++ b/sdk/core/azure-core/azure/core/async_paging.py @@ -70,9 +70,7 @@ class AsyncPageIterator(AsyncIterator[AsyncIterator[ReturnType]]): def __init__( self, get_next: Callable[[Optional[str]], Awaitable[ResponseType]], - extract_data: Callable[ - [ResponseType], Awaitable[Tuple[str, AsyncIterator[ReturnType]]] - ], + extract_data: Callable[[ResponseType], Awaitable[Tuple[str, AsyncIterator[ReturnType]]]], continuation_token: Optional[str] = None, ) -> None: """Return an async iterator of pages. @@ -101,9 +99,7 @@ async def __anext__(self) -> AsyncIterator[ReturnType]: self._did_a_call_already = True - self.continuation_token, self._current_page = await self._extract_data( - self._response - ) + self.continuation_token, self._current_page = await self._extract_data(self._response) # If current_page was a sync list, wrap it async-like if isinstance(self._current_page, collections.abc.Iterable): @@ -123,9 +119,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._kwargs = kwargs self._page_iterator: Optional[AsyncIterator[AsyncIterator[ReturnType]]] = None self._page: Optional[AsyncIterator[ReturnType]] = None - self._page_iterator_class = self._kwargs.pop( - "page_iterator_class", AsyncPageIterator - ) + self._page_iterator_class = self._kwargs.pop("page_iterator_class", AsyncPageIterator) def by_page( self, @@ -139,9 +133,7 @@ def by_page( this generator will begin returning results from this point. :returns: An async iterator of pages (themselves async iterator of objects) """ - return self._page_iterator_class( - *self._args, **self._kwargs, continuation_token=continuation_token - ) + return self._page_iterator_class(*self._args, **self._kwargs, continuation_token=continuation_token) async def __anext__(self) -> ReturnType: if self._page_iterator is None: diff --git a/sdk/core/azure-core/azure/core/credentials.py b/sdk/core/azure-core/azure/core/credentials.py index dbf09d7d2e8f..cb06e7d4cdf8 100644 --- a/sdk/core/azure-core/azure/core/credentials.py +++ b/sdk/core/azure-core/azure/core/credentials.py @@ -24,11 +24,7 @@ class TokenCredential(Protocol): """Protocol for classes able to provide OAuth tokens.""" def get_token( - self, - *scopes: str, - claims: Optional[str] = None, - tenant_id: Optional[str] = None, - **kwargs: Any + self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any ) -> AccessToken: """Request an access token for `scopes`. diff --git a/sdk/core/azure-core/azure/core/credentials_async.py b/sdk/core/azure-core/azure/core/credentials_async.py index 378dd91d8b28..0c02cc37281f 100644 --- a/sdk/core/azure-core/azure/core/credentials_async.py +++ b/sdk/core/azure-core/azure/core/credentials_async.py @@ -12,11 +12,7 @@ class AsyncTokenCredential(Protocol): """Protocol for classes able to provide OAuth tokens.""" async def get_token( - self, - *scopes: str, - claims: Optional[str] = None, - tenant_id: Optional[str] = None, - **kwargs: Any + self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any ) -> _AccessToken: """Request an access token for `scopes`. diff --git a/sdk/core/azure-core/azure/core/exceptions.py b/sdk/core/azure-core/azure/core/exceptions.py index 177c6070ab4c..8c702a7580dc 100644 --- a/sdk/core/azure-core/azure/core/exceptions.py +++ b/sdk/core/azure-core/azure/core/exceptions.py @@ -164,10 +164,7 @@ def __init__(self, json_object: Dict[str, Any]): self.message: Optional[str] = json_object.get(cls.MESSAGE_LABEL) if not (self.code or self.message): - raise ValueError( - "Impossible to extract code/message from received JSON:\n" - + json.dumps(json_object) - ) + raise ValueError("Impossible to extract code/message from received JSON:\n" + json.dumps(json_object)) # Optional fields self.target: Optional[str] = json_object.get(cls.TARGET_LABEL) @@ -209,9 +206,7 @@ def message_details(self) -> str: error_str += "\n".join("\t" + s for s in str(error_obj).splitlines()) if self.innererror: - error_str += "\nInner error: {}".format( - json.dumps(self.innererror, indent=4) - ) + error_str += "\nInner error: {}".format(json.dumps(self.innererror, indent=4)) return error_str @@ -236,9 +231,7 @@ class AzureError(Exception): def __init__(self, message, *args, **kwargs): self.inner_exception = kwargs.get("error") self.exc_type, self.exc_value, self.exc_traceback = sys.exc_info() - self.exc_type = ( - self.exc_type.__name__ if self.exc_type else type(self.inner_exception) - ) + self.exc_type = self.exc_type.__name__ if self.exc_type else type(self.inner_exception) self.exc_msg = "{}, {}: {}".format(message, self.exc_type, self.exc_value) self.message = str(message) self.continuation_token = kwargs.get("continuation_token") @@ -310,9 +303,7 @@ def __init__(self, message=None, response=None, **kwargs): self.model = model else: # autorest azure-core, for KV 1.0, Storage 12.0, etc. self.model: Optional[Any] = getattr(self, "error", None) - self.error: Optional[ODataV4Format] = self._parse_odata_body( - error_format, response - ) + self.error: Optional[ODataV4Format] = self._parse_odata_body(error_format, response) # By priority, message is: # - odatav4 message, OR @@ -321,16 +312,12 @@ def __init__(self, message=None, response=None, **kwargs): if self.error: message = str(self.error) else: - message = message or "Operation returned an invalid status '{}'".format( - self.reason - ) + message = message or "Operation returned an invalid status '{}'".format(self.reason) super(HttpResponseError, self).__init__(message=message, **kwargs) @staticmethod - def _parse_odata_body( - error_format: Type[ODataV4Format], response: "_HttpResponseBase" - ) -> Optional[ODataV4Format]: + def _parse_odata_body(error_format: Type[ODataV4Format], response: "_HttpResponseBase") -> Optional[ODataV4Format]: try: odata_json = json.loads(response.text()) return error_format(odata_json) @@ -436,18 +423,10 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None: try: error_node = self.odata_json["error"] self._error_format = self._ERROR_FORMAT(error_node) - self.__dict__.update( - { - k: v - for k, v in self._error_format.__dict__.items() - if v is not None - } - ) + self.__dict__.update({k: v for k, v in self._error_format.__dict__.items() if v is not None}) except Exception: # pylint: disable=broad-except _LOGGER.info("Received error message was not valid OdataV4 format.") - self._error_format = "JSON was invalid for format " + str( - self._ERROR_FORMAT - ) + self._error_format = "JSON was invalid for format " + str(self._ERROR_FORMAT) def __str__(self): if self._error_format: @@ -465,9 +444,7 @@ class StreamConsumedError(AzureError): def __init__(self, response): message = ( "You are attempting to read or stream the content from request {}. " - "You have likely already consumed this stream, so it can not be accessed anymore.".format( - response.request - ) + "You have likely already consumed this stream, so it can not be accessed anymore.".format(response.request) ) super(StreamConsumedError, self).__init__(message) diff --git a/sdk/core/azure-core/azure/core/messaging.py b/sdk/core/azure-core/azure/core/messaging.py index 5140558335e7..414447dff11f 100644 --- a/sdk/core/azure-core/azure/core/messaging.py +++ b/sdk/core/azure-core/azure/core/messaging.py @@ -124,9 +124,7 @@ def __init__( if self.extensions: for key in self.extensions.keys(): if not key.islower() or not key.isalnum(): - raise ValueError( - "Extension attributes should be lower cased and alphanumeric." - ) + raise ValueError("Extension attributes should be lower cased and alphanumeric.") if kwargs: remaining = ", ".join(kwargs.keys()) @@ -163,17 +161,13 @@ def from_dict(cls, event: Dict[str, Any]) -> "CloudEvent": ] if "data" in event and "data_base64" in event: - raise ValueError( - "Invalid input. Only one of data and data_base64 must be present." - ) + raise ValueError("Invalid input. Only one of data and data_base64 must be present.") if "data" in event: data = event.get("data") kwargs["data"] = data if data is not None else NULL elif "data_base64" in event: - kwargs["data"] = b64decode( - cast(Union[str, bytes], event.get("data_base64")) - ) + kwargs["data"] = b64decode(cast(Union[str, bytes], event.get("data_base64"))) for item in ["datacontenttype", "dataschema", "subject"]: if item in event: diff --git a/sdk/core/azure-core/azure/core/paging.py b/sdk/core/azure-core/azure/core/paging.py index f0ee547039a3..beed33671674 100644 --- a/sdk/core/azure-core/azure/core/paging.py +++ b/sdk/core/azure-core/azure/core/paging.py @@ -98,13 +98,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._args = args self._kwargs = kwargs self._page_iterator: Optional[Iterator[ReturnType]] = None - self._page_iterator_class = self._kwargs.pop( - "page_iterator_class", PageIterator - ) + self._page_iterator_class = self._kwargs.pop("page_iterator_class", PageIterator) - def by_page( - self, continuation_token: Optional[str] = None - ) -> Iterator[Iterator[ReturnType]]: + def by_page(self, continuation_token: Optional[str] = None) -> Iterator[Iterator[ReturnType]]: """Get an iterator of pages of objects, instead of an iterator of objects. :param str continuation_token: @@ -113,14 +109,10 @@ def by_page( this generator will begin returning results from this point. :returns: An iterator of pages (themselves iterator of objects) """ - return self._page_iterator_class( - continuation_token=continuation_token, *self._args, **self._kwargs - ) + return self._page_iterator_class(continuation_token=continuation_token, *self._args, **self._kwargs) def __repr__(self) -> str: - return "".format( - hex(id(self)) - ) + return "".format(hex(id(self))) def __iter__(self) -> Iterator[ReturnType]: """Return 'self'.""" diff --git a/sdk/core/azure-core/azure/core/pipeline/_base.py b/sdk/core/azure-core/azure/core/pipeline/_base.py index ecb62e9677ed..483d3047dccf 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/_base.py @@ -121,9 +121,7 @@ class Pipeline(AbstractContextManager, Generic[HTTPRequestType, HTTPResponseType :caption: Builds the pipeline for synchronous transport. """ - def __init__( - self, transport: HttpTransportType, policies: Optional[PoliciesType] = None - ) -> None: + def __init__(self, transport: HttpTransportType, policies: Optional[PoliciesType] = None) -> None: self._impl_policies: List[HTTPPolicy] = [] self._transport = transport @@ -194,12 +192,6 @@ def run(self, request: HTTPRequestType, **kwargs: Any) -> PipelineResponse: """ self._prepare_multipart(request) context = PipelineContext(self._transport, **kwargs) - pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest( - request, context - ) - first_node = ( - self._impl_policies[0] - if self._impl_policies - else _TransportRunner(self._transport) - ) + pipeline_request: PipelineRequest[HTTPRequestType] = PipelineRequest(request, context) + first_node = self._impl_policies[0] if self._impl_policies else _TransportRunner(self._transport) return first_node.send(pipeline_request) # type: ignore diff --git a/sdk/core/azure-core/azure/core/pipeline/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/_base_async.py index 7cbd458da942..6584ec82a875 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_base_async.py @@ -33,9 +33,7 @@ AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") ImplPoliciesType = List[ - AsyncHTTPPolicy[ # pylint: disable=unsubscriptable-object - HTTPRequestType, AsyncHTTPResponseType - ] + AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType] # pylint: disable=unsubscriptable-object ] AsyncPoliciesType = List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]] @@ -103,9 +101,7 @@ async def send(self, request): ) -class AsyncPipeline( - AbstractAsyncContextManager, Generic[HTTPRequestType, AsyncHTTPResponseType] -): +class AsyncPipeline(AbstractAsyncContextManager, Generic[HTTPRequestType, AsyncHTTPResponseType]): """Async pipeline implementation. This is implemented as a context manager, that will activate the context @@ -192,9 +188,5 @@ async def run(self, request: HTTPRequestType, **kwargs: Any): await self._prepare_multipart(request) context = PipelineContext(self._transport, **kwargs) pipeline_request = PipelineRequest(request, context) - first_node = ( - self._impl_policies[0] - if self._impl_policies - else _AsyncTransportRunner(self._transport) - ) + first_node = self._impl_policies[0] if self._impl_policies else _AsyncTransportRunner(self._transport) return await first_node.send(pipeline_request) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index c19460f73a2c..a51cf2407e2d 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -34,9 +34,7 @@ def await_result(func, *args, **kwargs): """If func returns an awaitable, raise that this runner can't handle it.""" result = func(*args, **kwargs) if hasattr(result, "__await__"): - raise TypeError( - "Policy {} returned awaitable object in non-async pipeline.".format(func) - ) + raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func)) return result diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 2a70d753f33b..e1296189f845 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -29,12 +29,7 @@ class _BearerTokenCredentialPolicyBase: :param str scopes: Lets you specify the type of access needed. """ - def __init__( - self, - credential: "TokenCredential", - *scopes: str, - **kwargs # pylint:disable=unused-argument - ) -> None: + def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> None: # pylint:disable=unused-argument super(_BearerTokenCredentialPolicyBase, self).__init__() self._scopes = scopes self._credential = credential @@ -92,9 +87,7 @@ def on_request(self, request: "PipelineRequest") -> None: self._token = self._credential.get_token(*self._scopes) self._update_headers(request.http_request.headers, self._token.token) - def authorize_request( - self, request: "PipelineRequest", *scopes: str, **kwargs - ) -> None: + def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs) -> None: """Acquire a token from the credential and authorize the request with it. Keyword arguments are passed to the credential's get_token method. The token will be cached and used to @@ -134,9 +127,7 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response - def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Authorize request according to an authentication challenge This method is called when the resource provider responds 401 with a WWW-Authenticate header. @@ -148,9 +139,7 @@ def on_challenge( # pylint:disable=unused-argument,no-self-use return False - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: """Executed after the request comes back from the next policy. :param request: Request to be modified after returning from the policy. @@ -181,10 +170,7 @@ class AzureKeyCredentialPolicy(SansIOHTTPPolicy): """ def __init__( - self, - credential: "AzureKeyCredential", - name: str, - **kwargs # pylint: disable=unused-argument + self, credential: "AzureKeyCredential", name: str, **kwargs # pylint: disable=unused-argument ) -> None: super(AzureKeyCredentialPolicy, self).__init__() self._credential = credential @@ -206,11 +192,7 @@ class AzureSasCredentialPolicy(SansIOHTTPPolicy): :raises: ValueError or TypeError """ - def __init__( - self, - credential: "AzureSasCredential", - **kwargs # pylint: disable=unused-argument - ) -> None: + def __init__(self, credential: "AzureSasCredential", **kwargs) -> None: # pylint: disable=unused-argument super(AzureSasCredentialPolicy, self).__init__() if not credential: raise ValueError("credential can not be None") diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 9bd4730e569b..4460b553024c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -28,9 +28,7 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy): :param str scopes: Lets you specify the type of access needed. """ - def __init__( - self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any - ) -> None: + def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None: # pylint:disable=unused-argument super().__init__() self._credential = credential @@ -38,18 +36,14 @@ def __init__( self._scopes = scopes self._token: Optional["AccessToken"] = None - async def on_request( - self, request: "PipelineRequest" - ) -> None: # pylint:disable=invalid-overridden-method + async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method """Adds a bearer token Authorization header to request and sends request to next policy. :param request: The pipeline request object to be modified. :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - _BearerTokenCredentialPolicyBase._enforce_https( # pylint:disable=protected-access - request - ) + _BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access if self._token is None or self._need_new_token(): async with self._lock: @@ -58,9 +52,7 @@ async def on_request( self._token = await self._credential.get_token(*self._scopes) request.http_request.headers["Authorization"] = "Bearer " + self._token.token - async def authorize_request( - self, request: "PipelineRequest", *scopes: str, **kwargs: Any - ) -> None: + async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. Keyword arguments are passed to the credential's get_token method. The token will be cached and used to @@ -103,9 +95,7 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": return response - async def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Authorize request according to an authentication challenge This method is called when the resource provider responds 401 with a WWW-Authenticate header. @@ -117,9 +107,7 @@ async def on_challenge( # pylint:disable=unused-argument,no-self-use return False - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> Optional[Awaitable[None]]: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> Optional[Awaitable[None]]: """Executed after the request comes back from the next policy. :param request: Request to be modified after returning from the policy. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_base.py b/sdk/core/azure-core/azure/core/pipeline/policies/_base.py index 8ed9fbb17ebd..078feda760d6 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_base.py @@ -94,9 +94,7 @@ class SansIOHTTPPolicy(Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]): but they will then be tied to AsyncPipeline usage. """ - def on_request( - self, request: PipelineRequest[HTTPRequestTypeVar] - ) -> Union[None, Awaitable[None]]: + def on_request(self, request: PipelineRequest[HTTPRequestTypeVar]) -> Union[None, Awaitable[None]]: """Is executed before sending the request from next policy. :param request: Request to be modified before sent from next policy. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py b/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py index c8c34eb16ef4..db4c41c6706e 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_custom_hook.py @@ -35,15 +35,11 @@ class CustomHookPolicy(SansIOHTTPPolicy): :keyword callback raw_response_hook: Callback function. Will be invoked on response. """ - def __init__( - self, **kwargs - ): # pylint: disable=unused-argument,super-init-not-called + def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called self._request_callback = kwargs.get("raw_request_hook") self._response_callback = kwargs.get("raw_response_hook") - def on_request( - self, request: PipelineRequest - ) -> None: # pylint: disable=arguments-differ + def on_request(self, request: PipelineRequest) -> None: # pylint: disable=arguments-differ """This is executed before sending the request to the next policy. :param request: The PipelineRequest object. diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py b/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py index 05f31620ffd0..5a81a8473e6d 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py @@ -81,9 +81,7 @@ class DistributedTracingPolicy(SansIOHTTPPolicy): _RESPONSE_ID = "x-ms-request-id" def __init__(self, **kwargs): - self._network_span_namer = kwargs.get( - "network_span_namer", _default_network_span_namer - ) + self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer) self._tracing_attributes = kwargs.get("tracing_attributes", {}) def on_request(self, request: "PipelineRequest") -> None: @@ -126,17 +124,13 @@ def end_span( if request_id is not None: span.add_attribute(self._REQUEST_ID, request_id) if response and self._RESPONSE_ID in response.headers: - span.add_attribute( - self._RESPONSE_ID, response.headers[self._RESPONSE_ID] - ) + span.add_attribute(self._RESPONSE_ID, response.headers[self._RESPONSE_ID]) if exc_info: span.__exit__(*exc_info) else: span.finish() - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: self.end_span(request, response=response.http_response) def on_exception(self, request: "PipelineRequest") -> None: diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py b/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py index 823216343d07..d08a5395a1f5 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_redirect.py @@ -48,9 +48,7 @@ def __init__(self, **kwargs): self.max_redirects = kwargs.get("redirect_max", 30) remove_headers = set(kwargs.get("redirect_remove_headers", [])) - self._remove_headers_on_redirect = remove_headers.union( - self.REDIRECT_HEADERS_BLACKLIST - ) + self._remove_headers_on_redirect = remove_headers.union(self.REDIRECT_HEADERS_BLACKLIST) redirect_status = set(kwargs.get("redirect_on_status_codes", [])) self._redirect_on_status_codes = redirect_status.union(self.REDIRECT_STATUSES) super(RedirectPolicyBase, self).__init__() @@ -107,9 +105,7 @@ def increment(self, settings, response, redirect_location): """ # TODO: Revise some of the logic here. settings["redirects"] -= 1 - settings["history"].append( - RequestHistory(response.http_request, http_response=response.http_response) - ) + settings["history"].append(RequestHistory(response.http_request, http_response=response.http_response)) redirected = urlparse(redirect_location) if not redirected.netloc: @@ -160,9 +156,7 @@ def send(self, request): response = self.next.send(request) redirect_location = self.get_redirect_location(response) if redirect_location and redirect_settings["allow"]: - retryable = self.increment( - redirect_settings, response, redirect_location - ) + retryable = self.increment(redirect_settings, response, redirect_location) request.http_request = response.http_request continue return response diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_redirect_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_redirect_async.py index bd8536daa725..bd7738042a5c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_redirect_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_redirect_async.py @@ -62,9 +62,7 @@ async def send(self, request): # pylint:disable=invalid-overridden-method response = await self.next.send(request) redirect_location = self.get_redirect_location(response) if redirect_location and redirect_settings["allow"]: - redirects_remaining = self.increment( - redirect_settings, response, redirect_location - ) + redirects_remaining = self.increment(redirect_settings, response, redirect_location) request.http_request = response.http_request continue return response diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py b/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py index 22e9d291f23e..b63dce649bd8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_retry.py @@ -74,9 +74,7 @@ def __init__(self, **kwargs): retry_codes = self._RETRY_CODES status_codes = kwargs.pop("retry_on_status_codes", []) self._retry_on_status_codes = set(status_codes) | retry_codes - self._method_whitelist = frozenset( - ["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"] - ) + self._method_whitelist = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]) self._respect_retry_after_header = True super(RetryPolicyBase, self).__init__() @@ -164,11 +162,7 @@ def _is_method_retryable(self, settings, request, response=None): :return: True if method should be retried upon. False if not in method allowlist. :rtype: bool """ - if ( - response - and request.method.upper() in ["POST", "PATCH"] - and response.status_code in [500, 503, 504] - ): + if response and request.method.upper() in ["POST", "PATCH"] and response.status_code in [500, 503, 504]: return True if request.method.upper() not in settings["methods"]: return False @@ -201,14 +195,9 @@ def is_retry(self, settings, response): has_retry_after = bool(response.http_response.headers.get("Retry-After")) if has_retry_after and self._respect_retry_after_header: return True - if not self._is_method_retryable( - settings, response.http_request, response=response.http_response - ): + if not self._is_method_retryable(settings, response.http_request, response=response.http_response): return False - return ( - settings["total"] - and response.http_response.status_code in self._retry_on_status_codes - ) + return settings["total"] and response.http_response.status_code in self._retry_on_status_codes def is_exhausted(self, settings): """Checks if any retries left. @@ -243,39 +232,28 @@ def increment(self, settings, response=None, error=None): """ settings["total"] -= 1 - if ( - isinstance(response, PipelineResponse) - and response.http_response.status_code == 202 - ): + if isinstance(response, PipelineResponse) and response.http_response.status_code == 202: return False if error and self._is_connection_error(error): # Connect retry? settings["connect"] -= 1 - settings["history"].append( - RequestHistory(response.http_request, error=error) - ) + settings["history"].append(RequestHistory(response.http_request, error=error)) elif error and self._is_read_error(error): # Read retry? settings["read"] -= 1 if hasattr(response, "http_request"): - settings["history"].append( - RequestHistory(response.http_request, error=error) - ) + settings["history"].append(RequestHistory(response.http_request, error=error)) else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - if hasattr(response, "http_request") and hasattr( - response, "http_response" - ): + if hasattr(response, "http_request") and hasattr(response, "http_response"): settings["history"].append( - RequestHistory( - response.http_request, http_response=response.http_response - ) + RequestHistory(response.http_request, http_response=response.http_response) ) if self.is_exhausted(settings): @@ -323,16 +301,12 @@ def _configure_timeout(self, request, absolute_timeout, is_response_error): # if connection_timeout is already set, ensure it doesn't exceed absolute_timeout connection_timeout = request.context.options.get("connection_timeout") if connection_timeout: - request.context.options["connection_timeout"] = min( - connection_timeout, absolute_timeout - ) + request.context.options["connection_timeout"] = min(connection_timeout, absolute_timeout) # otherwise, try to ensure the transport's configured connection_timeout doesn't exceed absolute_timeout # ("connection_config" isn't defined on Async/HttpTransport but all implementations in this library have it) elif hasattr(request.context.transport, "connection_config"): - default_timeout = getattr( - request.context.transport.connection_config, "timeout", absolute_timeout - ) + default_timeout = getattr(request.context.transport.connection_config, "timeout", absolute_timeout) try: if absolute_timeout < default_timeout: request.context.options["connection_timeout"] = absolute_timeout @@ -475,9 +449,7 @@ def send(self, request): if self.is_retry(retry_settings, response): retry_active = self.increment(retry_settings, response=response) if retry_active: - self.sleep( - retry_settings, request.context.transport, response=response - ) + self.sleep(retry_settings, request.context.transport, response=response) is_response_error = True continue break @@ -486,12 +458,8 @@ def send(self, request): # succeed--we'll never have a response to it, so propagate the exception raise except AzureError as err: - if absolute_timeout > 0 and self._is_method_retryable( - retry_settings, request.http_request - ): - retry_active = self.increment( - retry_settings, response=request, error=err - ) + if absolute_timeout > 0 and self._is_method_retryable(retry_settings, request.http_request): + retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) if isinstance(err, ServiceRequestError): diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_retry_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_retry_async.py index d9df848f6d0d..38a57bb7056e 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_retry_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_retry_async.py @@ -75,9 +75,7 @@ class AsyncRetryPolicy(RetryPolicyBase, AsyncHTTPPolicy): :caption: Configuring an async retry policy. """ - async def _sleep_for_retry( - self, response, transport - ): # pylint:disable=invalid-overridden-method + async def _sleep_for_retry(self, response, transport): # pylint:disable=invalid-overridden-method """Sleep based on the Retry-After response header value. :param response: The PipelineResponse object. @@ -90,9 +88,7 @@ async def _sleep_for_retry( return True return False - async def _sleep_backoff( - self, settings, transport - ): # pylint:disable=invalid-overridden-method + async def _sleep_backoff(self, settings, transport): # pylint:disable=invalid-overridden-method """Sleep using exponential backoff. Immediately returns if backoff is 0. :param dict settings: The retry settings. @@ -103,9 +99,7 @@ async def _sleep_backoff( return await transport.sleep(backoff) - async def sleep( - self, settings, transport, response=None - ): # pylint:disable=invalid-overridden-method + async def sleep(self, settings, transport, response=None): # pylint:disable=invalid-overridden-method """Sleep between retry attempts. This method will respect a server's ``Retry-After`` response header @@ -150,9 +144,7 @@ async def send(self, request): # pylint:disable=invalid-overridden-method if self.is_retry(retry_settings, response): retry_active = self.increment(retry_settings, response=response) if retry_active: - await self.sleep( - retry_settings, request.context.transport, response=response - ) + await self.sleep(retry_settings, request.context.transport, response=response) is_response_error = True continue break @@ -161,12 +153,8 @@ async def send(self, request): # pylint:disable=invalid-overridden-method # succeed--we'll never have a response to it, so propagate the exception raise except AzureError as err: - if absolute_timeout > 0 and self._is_method_retryable( - retry_settings, request.http_request - ): - retry_active = self.increment( - retry_settings, response=request, error=err - ) + if absolute_timeout > 0 and self._is_method_retryable(retry_settings, request.http_request): + retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) if isinstance(err, ServiceRequestError): diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py index fea1da1c7d77..9751be46de7c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_universal.py @@ -234,9 +234,7 @@ def __init__( def user_agent(self) -> str: """The current user agent value.""" if self.use_env: - add_user_agent_header = os.environ.get( - self._ENV_ADDITIONAL_USER_AGENT, None - ) + add_user_agent_header = os.environ.get(self._ENV_ADDITIONAL_USER_AGENT, None) if add_user_agent_header is not None: return "{} {}".format(self._user_agent, add_user_agent_header) return self._user_agent @@ -285,9 +283,7 @@ class NetworkTraceLoggingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseTy :caption: Configuring a network trace logging policy. """ - def __init__( - self, logging_enable=False, **kwargs - ): # pylint: disable=unused-argument + def __init__(self, logging_enable=False, **kwargs): # pylint: disable=unused-argument self.enable_http_logger = logging_enable def on_request( @@ -367,9 +363,7 @@ def on_response( if header and pattern.match(header): filename = header.partition("=")[2] log_string += "\nFile attachments: {}".format(filename) - elif http_response.headers.get("content-type", "").endswith( - "octet-stream" - ): + elif http_response.headers.get("content-type", "").endswith("octet-stream"): log_string += "\nBody contains binary data." elif http_response.headers.get("content-type", "").startswith("image"): log_string += "\nBody contains image data." @@ -435,31 +429,17 @@ class HttpLoggingPolicy( MULTI_RECORD_LOG = "AZURE_SDK_LOGGING_MULTIRECORD" def __init__(self, logger=None, **kwargs): # pylint: disable=unused-argument - self.logger = logger or logging.getLogger( - "azure.core.pipeline.policies.http_logging_policy" - ) + self.logger = logger or logging.getLogger("azure.core.pipeline.policies.http_logging_policy") self.allowed_query_params = set() self.allowed_header_names = set(self.__class__.DEFAULT_HEADERS_ALLOWLIST) def _redact_query_param(self, key, value): - lower_case_allowed_query_params = [ - param.lower() for param in self.allowed_query_params - ] - return ( - value - if key.lower() in lower_case_allowed_query_params - else HttpLoggingPolicy.REDACTED_PLACEHOLDER - ) + lower_case_allowed_query_params = [param.lower() for param in self.allowed_query_params] + return value if key.lower() in lower_case_allowed_query_params else HttpLoggingPolicy.REDACTED_PLACEHOLDER def _redact_header(self, key, value): - lower_case_allowed_header_names = [ - header.lower() for header in self.allowed_header_names - ] - return ( - value - if key.lower() in lower_case_allowed_header_names - else HttpLoggingPolicy.REDACTED_PLACEHOLDER - ) + lower_case_allowed_header_names = [header.lower() for header in self.allowed_header_names] + return value if key.lower() in lower_case_allowed_header_names else HttpLoggingPolicy.REDACTED_PLACEHOLDER def on_request( # pylint: disable=too-many-return-statements self, request: PipelineRequest[HTTPRequestType] @@ -473,9 +453,7 @@ def on_request( # pylint: disable=too-many-return-statements # Get logger in my context first (request has been retried) # then read from kwargs (pop if that's the case) # then use my instance logger - logger = request.context.setdefault( - "logger", options.pop("logger", self.logger) - ) + logger = request.context.setdefault("logger", options.pop("logger", self.logger)) if not logger.isEnabledFor(logging.INFO): return @@ -483,9 +461,7 @@ def on_request( # pylint: disable=too-many-return-statements try: parsed_url = list(urllib.parse.urlparse(http_request.url)) parsed_qp = urllib.parse.parse_qsl(parsed_url[4], keep_blank_values=True) - filtered_qp = [ - (key, self._redact_query_param(key, value)) for key, value in parsed_qp - ] + filtered_qp = [(key, self._redact_query_param(key, value)) for key, value in parsed_qp] # 4 is query parsed_url[4] = "&".join(["=".join(part) for part in filtered_qp]) redacted_url = urllib.parse.urlunparse(parsed_url) @@ -551,9 +527,7 @@ def on_response( # then use my instance logger # If on_request was called, should always read from context options = request.context.options - logger = request.context.setdefault( - "logger", options.pop("logger", self.logger) - ) + logger = request.context.setdefault("logger", options.pop("logger", self.logger)) try: if not logger.isEnabledFor(logging.INFO): @@ -590,11 +564,7 @@ class ContentDecodePolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): # Name used in context CONTEXT_NAME = "deserialized_data" - def __init__( - self, - response_encoding: Optional[str] = None, - **kwargs # pylint: disable=unused-argument - ) -> None: + def __init__(self, response_encoding: Optional[str] = None, **kwargs) -> None: # pylint: disable=unused-argument self._response_encoding = response_encoding @classmethod @@ -663,9 +633,7 @@ def _json_attemp(data): # The function hack is because Py2.7 messes up with exception # context otherwise. _LOGGER.critical("Wasn't XML not JSON, failing") - raise_with_traceback( - DecodeError, message="XML is invalid", response=response - ) + raise_with_traceback(DecodeError, message="XML is invalid", response=response) elif mime_type.startswith("text/"): return data_as_str raise DecodeError("Cannot deserialize content-type: {}".format(mime_type)) @@ -706,9 +674,7 @@ def deserialize_from_http_generics( # even if it's likely dead code if not inspect.iscoroutinefunction(response.read): # type: ignore response.read() # type: ignore - return cls.deserialize_from_text( - response.text(encoding), mime_type, response=response - ) + return cls.deserialize_from_text(response.text(encoding), mime_type, response=response) def on_request(self, request: PipelineRequest) -> None: options = request.context.options @@ -767,9 +733,7 @@ class ProxyPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]): :caption: Configuring a proxy policy. """ - def __init__( - self, proxies=None, **kwargs - ): # pylint: disable=unused-argument,super-init-not-called + def __init__(self, proxies=None, **kwargs): # pylint: disable=unused-argument,super-init-not-called self.proxies = proxies def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/__init__.py b/sdk/core/azure-core/azure/core/pipeline/transport/__init__.py index ef473d245463..cb72c6a914ce 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/__init__.py @@ -114,6 +114,4 @@ def __getattr__(name): raise ImportError("trio package is not installed") if transport: return transport - raise AttributeError( - f"module 'azure.core.pipeline.transport' has no attribute {name}" - ) + raise AttributeError(f"module 'azure.core.pipeline.transport' has no attribute {name}") diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 3c75ec9bad1f..8678e5e2a7f8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -86,18 +86,9 @@ class AioHttpTransport(AsyncHttpTransport): :caption: Asynchronous transport with aiohttp. """ - def __init__( - self, - *, - session: Optional[aiohttp.ClientSession] = None, - loop=None, - session_owner=True, - **kwargs - ): + def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner=True, **kwargs): if loop and sys.version_info >= (3, 10): - raise ValueError( - "Starting with Python 3.10, asyncio doesn’t support loop as a parameter anymore" - ) + raise ValueError("Starting with Python 3.10, asyncio doesn’t support loop as a parameter anymore") self._loop = loop self._session_owner = session_owner self.session = session @@ -154,18 +145,14 @@ def _get_request_data(self, request): # pylint: disable=no-self-use for form_file, data in request.files.items(): content_type = data[2] if len(data) > 2 else None try: - form_data.add_field( - form_file, data[1], filename=data[0], content_type=content_type - ) + form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type) except IndexError: raise ValueError("Invalid formdata formatting: {}".format(data)) return form_data return request.data @overload - async def send( - self, request: HttpRequest, **config: Any - ) -> Optional[AsyncHttpResponse]: + async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]: """Send the request using this HTTP sender. Will pre-load the body into memory to be available with a sync method. @@ -183,9 +170,7 @@ async def send( """ @overload - async def send( - self, request: "RestHttpRequest", **config: Any - ) -> Optional["RestAsyncHttpResponse"]: + async def send(self, request: "RestHttpRequest", **config: Any) -> Optional["RestAsyncHttpResponse"]: """Send the `azure.core.rest` request using this HTTP sender. Will pre-load the body into memory to be available with a sync method. @@ -248,12 +233,8 @@ async def send(self, request, **config): try: stream_response = config.pop("stream", False) timeout = config.pop("connection_timeout", self.connection_config.timeout) - read_timeout = config.pop( - "read_timeout", self.connection_config.read_timeout - ) - socket_timeout = aiohttp.ClientTimeout( - sock_connect=timeout, sock_read=read_timeout - ) + read_timeout = config.pop("read_timeout", self.connection_config.read_timeout) + socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout) result = await self.session.request( # type: ignore request.method, request.url, @@ -301,9 +282,7 @@ class AioHttpStreamDownloadGenerator(AsyncIterator): on the *content-encoding* header. """ - def __init__( - self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True - ) -> None: + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True) -> None: self.pipeline = pipeline self.request = response.request self.response = response @@ -365,16 +344,9 @@ class AioHttpTransportResponse(AsyncHttpResponse): """ def __init__( - self, - request: HttpRequest, - aiohttp_response: aiohttp.ClientResponse, - block_size=None, - *, - decompress=True + self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None, *, decompress=True ) -> None: - super(AioHttpTransportResponse, self).__init__( - request, aiohttp_response, block_size=block_size - ) + super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size) # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse self.status_code = aiohttp_response.status self.headers = CIMultiDict(aiohttp_response.headers) @@ -411,9 +383,7 @@ def text(self, encoding: Optional[str] = None) -> str: except LookupError: encoding = None if not encoding: - if mimetype.type == "application" and ( - mimetype.subtype == "json" or mimetype.subtype == "rdap" - ): + if mimetype.type == "application" and (mimetype.subtype == "json" or mimetype.subtype == "rdap"): # RFC 7159 states that the default encoding is UTF-8. # RFC 7483 defines application/rdap+json encoding = "utf-8" @@ -458,8 +428,6 @@ def __getstate__(self): state = self.__dict__.copy() # Remove the unpicklable entries. - state[ - "internal_response" - ] = None # aiohttp response are not pickable (see headers comments) + state["internal_response"] = None # aiohttp response are not pickable (see headers comments) state["headers"] = CIMultiDict(self.headers) # MultiDictProxy is not pickable return state diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 1a5a568a6325..77c286abb1e2 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -94,9 +94,7 @@ def _format_url_section(template, **kwargs): return template.format(**kwargs) except KeyError as key: formatted_components = template.split("/") - components = [ - c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c - ] + components = [c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c] template = "/".join(components) # No URL sections left - returning None @@ -114,9 +112,7 @@ def _urljoin(base_url: str, stub_url: str) -> str: return parsed.geturl() -class HttpTransport( - AbstractContextManager, abc.ABC, Generic[HTTPRequestType, HTTPResponseType] -): +class HttpTransport(AbstractContextManager, abc.ABC, Generic[HTTPRequestType, HTTPResponseType]): """An http sender ABC.""" @abc.abstractmethod @@ -206,9 +202,7 @@ def body(self, value: DataType): self.data = value @staticmethod - def _format_data( - data: Union[str, IO] - ) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]: + def _format_data(data: Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]: """Format field data according to whether it is a stream or a string for a form-data request. @@ -235,9 +229,7 @@ def set_streamed_data_body(self, data): if not isinstance(data, binary_type) and not any( hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] ): - raise TypeError( - "A streamable data source must be an open file-like object or iterable." - ) + raise TypeError("A streamable data source must be an open file-like object or iterable.") self.data = data self.files = None @@ -294,9 +286,7 @@ def set_formdata_body(self, data=None): self.data = {f: d for f, d in data.items() if d is not None} self.files = None else: # Assume "multipart/form-data" - self.files = { - f: self._format_data(d) for f, d in data.items() if d is not None - } + self.files = {f: self._format_data(d) for f, d in data.items() if d is not None} self.data = None def set_bytes_body(self, data): @@ -405,9 +395,7 @@ def _decode_parts( requests: List[HttpRequest], ) -> List["HttpResponse"]: """Rebuild an HTTP response from pure string.""" - return _decode_parts_helper( - self, message, http_response_type, requests, _deserialize_response - ) + return _decode_parts_helper(self, message, http_response_type, requests, _deserialize_response) def _get_raw_parts( self, http_response_type: Optional[Type["_HttpResponseBase"]] = None @@ -417,9 +405,7 @@ def _get_raw_parts( If parts are application/http use http_response_type or HttpClientTransportResponse as envelope. """ - return _get_raw_parts_helper( - self, http_response_type or HttpClientTransportResponse - ) + return _get_raw_parts_helper(self, http_response_type or HttpClientTransportResponse) def raise_for_status(self) -> None: """Raises an HttpResponseError if the response has an error status code. @@ -430,12 +416,8 @@ def raise_for_status(self) -> None: def __repr__(self): # there doesn't have to be a content type - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "<{}: {} {}{}>".format( - type(self).__name__, self.status_code, self.reason, content_type_str - ) + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "<{}: {} {}{}>".format(type(self).__name__, self.status_code, self.reason, content_type_str) class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method @@ -487,9 +469,7 @@ class HttpClientTransportResponse(_HttpClientTransportResponse, HttpResponse): """ -def _deserialize_response( - http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse -): +def _deserialize_response(http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse): local_socket = BytesIOSocket(http_response_as_bytes) response = _HTTPResponse(local_socket, method=http_request.method) response.begin() @@ -602,9 +582,7 @@ def get( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "GET", url, params, headers, content, form_content, None - ) + request = self._request("GET", url, params, headers, content, form_content, None) request.method = "GET" return request @@ -627,9 +605,7 @@ def put( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "PUT", url, params, headers, content, form_content, stream_content - ) + request = self._request("PUT", url, params, headers, content, form_content, stream_content) return request def post( @@ -651,9 +627,7 @@ def post( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "POST", url, params, headers, content, form_content, stream_content - ) + request = self._request("POST", url, params, headers, content, form_content, stream_content) return request def head( @@ -675,9 +649,7 @@ def head( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "HEAD", url, params, headers, content, form_content, stream_content - ) + request = self._request("HEAD", url, params, headers, content, form_content, stream_content) return request def patch( @@ -699,9 +671,7 @@ def patch( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "PATCH", url, params, headers, content, form_content, stream_content - ) + request = self._request("PATCH", url, params, headers, content, form_content, stream_content) return request def delete( @@ -722,9 +692,7 @@ def delete( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "DELETE", url, params, headers, content, form_content, None - ) + request = self._request("DELETE", url, params, headers, content, form_content, None) return request def merge( @@ -745,17 +713,11 @@ def merge( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request( - "MERGE", url, params, headers, content, form_content, None - ) + request = self._request("MERGE", url, params, headers, content, form_content, None) return request def options( - self, - url: str, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - **kwargs + self, url: str, params: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None, **kwargs ) -> HttpRequest: """Create a OPTIONS request object. @@ -769,7 +731,5 @@ def options( """ content = kwargs.get("content") form_content = kwargs.get("form_content") - request = self._request( - "OPTIONS", url, params, headers, content, form_content, None - ) + request = self._request("OPTIONS", url, params, headers, content, form_content, None) return request diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index aa749b10b4cb..6ddce089bc17 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -80,13 +80,9 @@ def parts(self) -> AsyncIterator: :raises ValueError: If the content is not multipart/mixed """ if not self.content_type or not self.content_type.startswith("multipart/mixed"): - raise ValueError( - "You can't get parts if the response is not multipart/mixed" - ) + raise ValueError("You can't get parts if the response is not multipart/mixed") - return _PartGenerator( - self, default_http_response_type=AsyncHttpClientTransportResponse - ) + return _PartGenerator(self, default_http_response_type=AsyncHttpClientTransportResponse) class AsyncHttpClientTransportResponse(_HttpClientTransportResponse, AsyncHttpResponse): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index 365703bc712a..2b0f4929e074 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -158,12 +158,8 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me headers=request.headers, data=data_to_send, files=request.files, - verify=kwargs.pop( - "connection_verify", self.connection_config.verify - ), - timeout=kwargs.pop( - "connection_timeout", self.connection_config.timeout - ), + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), cert=kwargs.pop("connection_cert", self.connection_config.cert), allow_redirects=False, **kwargs @@ -207,9 +203,7 @@ async def send(self, request, **kwargs): # pylint:disable=invalid-overridden-me await _handle_no_stream_rest_response(retval) return retval - return AsyncioRequestsTransportResponse( - request, response, self.connection_config.data_block_size - ) + return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size) class AsyncioStreamDownloadGenerator(AsyncIterator): @@ -221,25 +215,19 @@ class AsyncioStreamDownloadGenerator(AsyncIterator): on the *content-encoding* header. """ - def __init__( - self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs - ) -> None: + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: self.pipeline = pipeline self.request = response.request self.response = response self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: - raise TypeError( - "Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]) - ) + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: - self.iter_content_func = _read_raw_stream( - internal_response, self.block_size - ) + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) self.content_length = int(response.headers.get("Content-Length", 0)) def __len__(self): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index e28fdd9419f1..359354907503 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -100,9 +100,7 @@ class _RequestsTransportResponseBase(_HttpResponseBase): """ def __init__(self, request, requests_response, block_size=None): - super(_RequestsTransportResponseBase, self).__init__( - request, requests_response, block_size=block_size - ) + super(_RequestsTransportResponseBase, self).__init__(request, requests_response, block_size=block_size) self.status_code = requests_response.status_code self.headers = requests_response.headers self.reason = requests_response.reason @@ -155,16 +153,12 @@ def __init__(self, pipeline, response, **kwargs): self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: - raise TypeError( - "Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]) - ) + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: - self.iter_content_func = _read_raw_stream( - internal_response, self.block_size - ) + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) self.content_length = int(response.headers.get("Content-Length", 0)) def __len__(self): @@ -319,21 +313,15 @@ def send(self, request, **kwargs): error: Optional[AzureErrorUnion] = None try: - connection_timeout = kwargs.pop( - "connection_timeout", self.connection_config.timeout - ) + connection_timeout = kwargs.pop("connection_timeout", self.connection_config.timeout) if isinstance(connection_timeout, tuple): if "read_timeout" in kwargs: - raise ValueError( - "Cannot set tuple connection_timeout and read_timeout together" - ) + raise ValueError("Cannot set tuple connection_timeout and read_timeout together") _LOGGER.warning("Tuple timeout setting is deprecated") timeout = connection_timeout else: - read_timeout = kwargs.pop( - "read_timeout", self.connection_config.read_timeout - ) + read_timeout = kwargs.pop("read_timeout", self.connection_config.read_timeout) timeout = (connection_timeout, read_timeout) response = self.session.request( # type: ignore request.method, @@ -385,6 +373,4 @@ def send(self, request, **kwargs): if not kwargs.get("stream"): _handle_non_stream_rest_response(retval) return retval - return RequestsTransportResponse( - request, response, self.connection_config.data_block_size - ) + return RequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index 68aa033b77d3..8dfb9f48c169 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -81,25 +81,19 @@ class TrioStreamDownloadGenerator(AsyncIterator): on the *content-encoding* header. """ - def __init__( - self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs - ) -> None: + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: self.pipeline = pipeline self.request = response.request self.response = response self.block_size = response.block_size decompress = kwargs.pop("decompress", True) if len(kwargs) > 0: - raise TypeError( - "Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]) - ) + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) internal_response = response.internal_response if decompress: self.iter_content_func = internal_response.iter_content(self.block_size) else: - self.iter_content_func = _read_raw_stream( - internal_response, self.block_size - ) + self.iter_content_func = _read_raw_stream(internal_response, self.block_size) self.content_length = int(response.headers.get("Content-Length", 0)) def __len__(self): @@ -114,11 +108,9 @@ async def __anext__(self): self.iter_content_func, ) except AttributeError: # trio < 0.12.1 - chunk = ( - await trio.run_sync_in_worker_thread( # pylint: disable=no-member - _iterate_response_content, - self.iter_content_func, - ) + chunk = await trio.run_sync_in_worker_thread( # pylint: disable=no-member + _iterate_response_content, + self.iter_content_func, ) if not chunk: raise _ResponseStopIteration() @@ -206,9 +198,7 @@ async def send( # pylint:disable=invalid-overridden-method :keyword dict proxies: will define the proxy to use. Proxy is a dict (protocol, url) """ - async def send( - self, request, **kwargs: Any - ): # pylint:disable=invalid-overridden-method + async def send(self, request, **kwargs: Any): # pylint:disable=invalid-overridden-method """Send the request using this HTTP sender. :param request: The HttpRequest @@ -235,12 +225,8 @@ async def send( headers=request.headers, data=data_to_send, files=request.files, - verify=kwargs.pop( - "connection_verify", self.connection_config.verify - ), - timeout=kwargs.pop( - "connection_timeout", self.connection_config.timeout - ), + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), cert=kwargs.pop("connection_cert", self.connection_config.cert), allow_redirects=False, **kwargs @@ -248,29 +234,21 @@ async def send( limiter=trio_limiter, ) except AttributeError: # trio < 0.12.1 - response = ( - await trio.run_sync_in_worker_thread( # pylint: disable=no-member - functools.partial( - self.session.request, - request.method, - request.url, - headers=request.headers, - data=request.data, - files=request.files, - verify=kwargs.pop( - "connection_verify", self.connection_config.verify - ), - timeout=kwargs.pop( - "connection_timeout", self.connection_config.timeout - ), - cert=kwargs.pop( - "connection_cert", self.connection_config.cert - ), - allow_redirects=False, - **kwargs - ), - limiter=trio_limiter, - ) + response = await trio.run_sync_in_worker_thread( # pylint: disable=no-member + functools.partial( + self.session.request, + request.method, + request.url, + headers=request.headers, + data=request.data, + files=request.files, + verify=kwargs.pop("connection_verify", self.connection_config.verify), + timeout=kwargs.pop("connection_timeout", self.connection_config.timeout), + cert=kwargs.pop("connection_cert", self.connection_config.cert), + allow_redirects=False, + **kwargs + ), + limiter=trio_limiter, ) response.raw.enforce_content_length = True @@ -308,6 +286,4 @@ async def send( await _handle_no_stream_rest_response(retval) return retval - return TrioRequestsTransportResponse( - request, response, self.connection_config.data_block_size - ) + return TrioRequestsTransportResponse(request, response, self.connection_config.data_block_size) diff --git a/sdk/core/azure-core/azure/core/polling/_async_poller.py b/sdk/core/azure-core/azure/core/polling/_async_poller.py index 5b98c65ee76c..c50222363531 100644 --- a/sdk/core/azure-core/azure/core/polling/_async_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_async_poller.py @@ -39,9 +39,7 @@ class AsyncPollingMethod(Generic[PollingReturnType]): """ABC class for polling method.""" - def initialize( - self, client: Any, initial_response: Any, deserialization_callback: Any - ) -> None: + def initialize(self, client: Any, initial_response: Any, deserialization_callback: Any) -> None: raise NotImplementedError("This method needs to be implemented") async def run(self) -> None: @@ -57,21 +55,11 @@ def resource(self) -> PollingReturnType: raise NotImplementedError("This method needs to be implemented") def get_continuation_token(self) -> str: - raise TypeError( - "Polling method '{}' doesn't support get_continuation_token".format( - self.__class__.__name__ - ) - ) + raise TypeError("Polling method '{}' doesn't support get_continuation_token".format(self.__class__.__name__)) @classmethod - def from_continuation_token( - cls, continuation_token: str, **kwargs - ) -> Tuple[Any, Any, Callable]: - raise TypeError( - "Polling method '{}' doesn't support from_continuation_token".format( - cls.__name__ - ) - ) + def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]: + raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__)) class AsyncNoPolling(_NoPolling): @@ -83,9 +71,7 @@ async def run(self): # pylint:disable=invalid-overridden-method """ -async def async_poller( - client, initial_response, deserialization_callback, polling_method -): +async def async_poller(client, initial_response, deserialization_callback, polling_method): """Async Poller for long running operations. .. deprecated:: 1.5.0 @@ -101,9 +87,7 @@ async def async_poller( :param polling_method: The polling strategy to adopt :type polling_method: ~azure.core.polling.PollingMethod """ - poller = AsyncLROPoller( - client, initial_response, deserialization_callback, polling_method - ) + poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method) return await poller @@ -137,9 +121,7 @@ def __init__( except AttributeError: pass - self._polling_method.initialize( - client, initial_response, deserialization_callback - ) + self._polling_method.initialize(client, initial_response, deserialization_callback) def polling_method(self) -> AsyncPollingMethod[PollingReturnType]: """Return the polling method associated to this poller.""" @@ -155,10 +137,7 @@ def continuation_token(self) -> str: @classmethod def from_continuation_token( - cls, - polling_method: AsyncPollingMethod[PollingReturnType], - continuation_token: str, - **kwargs + cls, polling_method: AsyncPollingMethod[PollingReturnType], continuation_token: str, **kwargs ) -> "AsyncLROPoller[PollingReturnType]": ( client, diff --git a/sdk/core/azure-core/azure/core/polling/_poller.py b/sdk/core/azure-core/azure/core/polling/_poller.py index c7d208d49548..84067f301392 100644 --- a/sdk/core/azure-core/azure/core/polling/_poller.py +++ b/sdk/core/azure-core/azure/core/polling/_poller.py @@ -41,9 +41,7 @@ class PollingMethod(Generic[PollingReturnType]): """ABC class for polling method.""" - def initialize( - self, client: Any, initial_response: Any, deserialization_callback: Any - ) -> None: + def initialize(self, client: Any, initial_response: Any, deserialization_callback: Any) -> None: raise NotImplementedError("This method needs to be implemented") def run(self) -> None: @@ -59,21 +57,11 @@ def resource(self) -> PollingReturnType: raise NotImplementedError("This method needs to be implemented") def get_continuation_token(self) -> str: - raise TypeError( - "Polling method '{}' doesn't support get_continuation_token".format( - self.__class__.__name__ - ) - ) + raise TypeError("Polling method '{}' doesn't support get_continuation_token".format(self.__class__.__name__)) @classmethod - def from_continuation_token( - cls, continuation_token: str, **kwargs - ) -> Tuple[Any, Any, Callable]: - raise TypeError( - "Polling method '{}' doesn't support from_continuation_token".format( - cls.__name__ - ) - ) + def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]: + raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__)) class NoPolling(PollingMethod): @@ -83,9 +71,7 @@ def __init__(self): self._initial_response = None self._deserialization_callback = None - def initialize( - self, _: Any, initial_response: Any, deserialization_callback: Callable - ) -> None: + def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable) -> None: self._initial_response = initial_response self._deserialization_callback = deserialization_callback @@ -115,15 +101,11 @@ def get_continuation_token(self) -> str: return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") @classmethod - def from_continuation_token( - cls, continuation_token: str, **kwargs - ) -> Tuple[Any, Any, Callable]: + def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]: try: deserialization_callback = kwargs["deserialization_callback"] except KeyError: - raise ValueError( - "Need kwarg 'deserialization_callback' to be recreated from continuation_token" - ) + raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") import pickle initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec @@ -161,9 +143,7 @@ def __init__( pass # Might raise a CloudError - self._polling_method.initialize( - client, initial_response, deserialization_callback - ) + self._polling_method.initialize(client, initial_response, deserialization_callback) # Prepare thread execution self._thread = None @@ -222,10 +202,7 @@ def continuation_token(self) -> str: @classmethod def from_continuation_token( - cls, - polling_method: PollingMethod[PollingReturnType], - continuation_token: str, - **kwargs + cls, polling_method: PollingMethod[PollingReturnType], continuation_token: str, **kwargs ) -> "LROPoller[PollingReturnType]": ( client, diff --git a/sdk/core/azure-core/azure/core/polling/async_base_polling.py b/sdk/core/azure-core/azure/core/polling/async_base_polling.py index 275f4b248d24..f57ae55eacdc 100644 --- a/sdk/core/azure-core/azure/core/polling/async_base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/async_base_polling.py @@ -46,9 +46,7 @@ async def run(self): # pylint:disable=invalid-overridden-method except BadStatus as err: self._status = "Failed" - raise HttpResponseError( - response=self._pipeline_response.http_response, error=err - ) + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) except BadResponse as err: self._status = "Failed" @@ -59,9 +57,7 @@ async def run(self): # pylint:disable=invalid-overridden-method ) except OperationFailed as err: - raise HttpResponseError( - response=self._pipeline_response.http_response, error=err - ) + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) async def _poll(self): # pylint:disable=invalid-overridden-method """Poll status of operation so long as operation is incomplete and @@ -99,15 +95,11 @@ async def _delay(self): # pylint:disable=invalid-overridden-method async def update_status(self): # pylint:disable=invalid-overridden-method """Update the current status of the LRO.""" - self._pipeline_response = await self.request_status( - self._operation.get_polling_url() - ) + self._pipeline_response = await self.request_status(self._operation.get_polling_url()) _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) self._status = self._operation.get_status(self._pipeline_response) - async def request_status( - self, status_link - ): # pylint:disable=invalid-overridden-method + async def request_status(self, status_link): # pylint:disable=invalid-overridden-method """Do a simple GET to this status link. This method re-inject 'x-ms-client-request-id'. @@ -115,9 +107,7 @@ async def request_status( :rtype: azure.core.pipeline.PipelineResponse """ if self._path_format_arguments: - status_link = self._client.format_url( - status_link, **self._path_format_arguments - ) + status_link = self._client.format_url(status_link, **self._path_format_arguments) # Re-inject 'x-ms-client-request-id' while polling if "request_id" not in self._operation_config: self._operation_config["request_id"] = self._get_request_id() @@ -127,9 +117,7 @@ async def request_status( from azure.core.rest import HttpRequest as RestHttpRequest request = RestHttpRequest("GET", status_link) - return await self._client.send_request( - request, _return_pipeline_response=True, **self._operation_config - ) + return await self._client.send_request(request, _return_pipeline_response=True, **self._operation_config) # if I am a azure.core.pipeline.transport.HttpResponse request = self._client.get(status_link) diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py index 7562db2f9131..0d8c0730e18e 100644 --- a/sdk/core/azure-core/azure/core/polling/base_polling.py +++ b/sdk/core/azure-core/azure/core/polling/base_polling.py @@ -107,11 +107,7 @@ def _raise_if_bad_http_status_and_method(response: "ResponseType") -> None: code = response.status_code if code in {200, 201, 202, 204}: return - raise BadStatus( - "Invalid return status {!r} for {!r} operation".format( - code, response.request.method - ) - ) + raise BadStatus("Invalid return status {!r} for {!r} operation".format(code, response.request.method)) def _is_empty(response: "ResponseType") -> bool: @@ -157,9 +153,7 @@ def get_status(self, pipeline_response: "PipelineResponseType") -> str: raise NotImplementedError() @abc.abstractmethod - def get_final_get_url( - self, pipeline_response: "PipelineResponseType" - ) -> Optional[str]: + def get_final_get_url(self, pipeline_response: "PipelineResponseType") -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str @@ -189,9 +183,7 @@ class OperationResourcePolling(LongRunningOperation): https://aka.ms/azsdk/autorest/openapi/lro-options """ - def __init__( - self, operation_location_header="operation-location", *, lro_options=None - ): + def __init__(self, operation_location_header="operation-location", *, lro_options=None): self._operation_location_header = operation_location_header # Store the initial URLs @@ -209,16 +201,13 @@ def get_polling_url(self) -> str: """Return the polling URL.""" return self._async_url - def get_final_get_url( - self, pipeline_response: "PipelineResponseType" - ) -> Optional[str]: + def get_final_get_url(self, pipeline_response: "PipelineResponseType") -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str """ if ( - self._lro_options.get(_LroOption.FINAL_STATE_VIA) - == _FinalStateViaOption.LOCATION_FINAL_STATE + self._lro_options.get(_LroOption.FINAL_STATE_VIA) == _FinalStateViaOption.LOCATION_FINAL_STATE and self._location_url ): return self._location_url @@ -276,9 +265,7 @@ def get_status(self, pipeline_response: "PipelineResponseType") -> str: """ response = pipeline_response.http_response if _is_empty(response): - raise BadResponse( - "The response from long running operation does not contain a body." - ) + raise BadResponse("The response from long running operation does not contain a body.") body = _as_json(response) status = body.get("status") @@ -302,9 +289,7 @@ def get_polling_url(self) -> str: """Return the polling URL.""" return self._location_url - def get_final_get_url( - self, pipeline_response: "PipelineResponseType" - ) -> Optional[str]: + def get_final_get_url(self, pipeline_response: "PipelineResponseType") -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str @@ -361,9 +346,7 @@ def set_initial_status(self, pipeline_response: "PipelineResponseType") -> str: def get_status(self, pipeline_response: "PipelineResponseType") -> str: return "Succeeded" - def get_final_get_url( - self, pipeline_response: "PipelineResponseType" - ) -> Optional[str]: + def get_final_get_url(self, pipeline_response: "PipelineResponseType") -> Optional[str]: """If a final GET is needed, returns the URL. :rtype: str @@ -383,12 +366,7 @@ class LROBasePolling(PollingMethod): # pylint: disable=too-many-instance-attrib """ def __init__( - self, - timeout=30, - lro_algorithms=None, - lro_options=None, - path_format_arguments=None, - **operation_config + self, timeout=30, lro_algorithms=None, lro_options=None, path_format_arguments=None, **operation_config ): self._lro_algorithms = lro_algorithms or [ OperationResourcePolling(lro_options=lro_options), @@ -412,9 +390,7 @@ def status(self): :rtype: str """ if not self._operation: - raise ValueError( - "set_initial_status was never called. Did you give this instance to a poller?" - ) + raise ValueError("set_initial_status was never called. Did you give this instance to a poller?") return self._status def finished(self): @@ -457,9 +433,7 @@ def initialize(self, client, initial_response, deserialization_callback): raise HttpResponseError(response=initial_response.http_response, error=err) except BadResponse as err: self._status = "Failed" - raise HttpResponseError( - response=initial_response.http_response, message=str(err), error=err - ) + raise HttpResponseError(response=initial_response.http_response, message=str(err), error=err) except OperationFailed as err: raise HttpResponseError(response=initial_response.http_response, error=err) @@ -469,30 +443,22 @@ def get_continuation_token(self) -> str: return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii") @classmethod - def from_continuation_token( - cls, continuation_token: str, **kwargs - ) -> Tuple[Any, Any, Callable]: + def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any, Any, Callable]: try: client = kwargs["client"] except KeyError: - raise ValueError( - "Need kwarg 'client' to be recreated from continuation_token" - ) + raise ValueError("Need kwarg 'client' to be recreated from continuation_token") try: deserialization_callback = kwargs["deserialization_callback"] except KeyError: - raise ValueError( - "Need kwarg 'deserialization_callback' to be recreated from continuation_token" - ) + raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") import pickle initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec # Restore the transport in the context - initial_response.context.transport = ( - client._pipeline._transport # pylint: disable=protected-access - ) + initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access return client, initial_response, deserialization_callback def run(self): @@ -501,9 +467,7 @@ def run(self): except BadStatus as err: self._status = "Failed" - raise HttpResponseError( - response=self._pipeline_response.http_response, error=err - ) + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) except BadResponse as err: self._status = "Failed" @@ -514,9 +478,7 @@ def run(self): ) except OperationFailed as err: - raise HttpResponseError( - response=self._pipeline_response.http_response, error=err - ) + raise HttpResponseError(response=self._pipeline_response.http_response, error=err) def _poll(self): """Poll status of operation so long as operation is incomplete and @@ -542,9 +504,7 @@ def _poll(self): self._pipeline_response = self.request_status(final_get_url) _raise_if_bad_http_status_and_method(self._pipeline_response.http_response) - def _parse_resource( - self, pipeline_response: "PipelineResponseType" - ) -> Optional[Any]: + def _parse_resource(self, pipeline_response: "PipelineResponseType") -> Optional[Any]: """Assuming this response is a resource, use the deserialization callback to parse it. If body is empty, assuming no resource to return. """ @@ -578,9 +538,7 @@ def update_status(self): self._status = self._operation.get_status(self._pipeline_response) def _get_request_id(self): - return self._pipeline_response.http_response.request.headers[ - "x-ms-client-request-id" - ] + return self._pipeline_response.http_response.request.headers["x-ms-client-request-id"] def request_status(self, status_link): """Do a simple GET to this status link. @@ -590,9 +548,7 @@ def request_status(self, status_link): :rtype: azure.core.pipeline.PipelineResponse """ if self._path_format_arguments: - status_link = self._client.format_url( - status_link, **self._path_format_arguments - ) + status_link = self._client.format_url(status_link, **self._path_format_arguments) # Re-inject 'x-ms-client-request-id' while polling if "request_id" not in self._operation_config: self._operation_config["request_id"] = self._get_request_id() @@ -602,9 +558,7 @@ def request_status(self, status_link): from azure.core.rest import HttpRequest as RestHttpRequest request = RestHttpRequest("GET", status_link) - return self._client.send_request( - request, _return_pipeline_response=True, **self._operation_config - ) + return self._client.send_request(request, _return_pipeline_response=True, **self._operation_config) # if I am a azure.core.pipeline.transport.HttpResponse request = self._client.get(status_link) return self._client._pipeline.run( # pylint: disable=protected-access diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py index 6b4775fc054b..62843ef26b04 100644 --- a/sdk/core/azure-core/azure/core/rest/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/rest/_aiohttp.py @@ -155,9 +155,7 @@ def __getattr__(self, attr): return super().__getattr__(attr) -class RestAioHttpTransportResponse( - AsyncHttpResponseImpl, _RestAioHttpTransportResponseBackcompatMixin -): +class RestAioHttpTransportResponse(AsyncHttpResponseImpl, _RestAioHttpTransportResponseBackcompatMixin): def __init__(self, *, internal_response, decompress: bool = True, **kwargs): headers = _CIMultiDict(internal_response.headers) super().__init__( @@ -176,9 +174,7 @@ def __init__(self, *, internal_response, decompress: bool = True, **kwargs): def __getstate__(self): state = self.__dict__.copy() # Remove the unpicklable entries. - state[ - "_internal_response" - ] = None # aiohttp response are not pickable (see headers comments) + state["_internal_response"] = None # aiohttp response are not pickable (see headers comments) state["headers"] = CIMultiDict(self.headers) # MultiDictProxy is not pickable return state diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py index 42dbff631b37..bcc2a8c71c06 100644 --- a/sdk/core/azure-core/azure/core/rest/_helpers.py +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -72,17 +72,9 @@ def _verify_data_object(name, value): if not isinstance(name, str): - raise TypeError( - "Invalid type for data name. Expected str, got {}: {}".format( - type(name), name - ) - ) + raise TypeError("Invalid type for data name. Expected str, got {}: {}".format(type(name), name)) if value is not None and not isinstance(value, (str, bytes, int, float)): - raise TypeError( - "Invalid type for data value. Expected primitive type, got {}: {}".format( - type(name), name - ) - ) + raise TypeError("Invalid type for data value. Expected primitive type, got {}: {}".format(type(name), name)) def set_urlencoded_body(data, has_files): @@ -105,9 +97,7 @@ def set_urlencoded_body(data, has_files): def set_multipart_body(files): - formatted_files = { - f: _format_data_helper(d) for f, d in files.items() if d is not None - } + formatted_files = {f: _format_data_helper(d) for f, d in files.items() if d is not None} return {}, formatted_files @@ -277,9 +267,7 @@ def _set_streamed_data_body(self, data): if not isinstance(data, binary_type) and not any( hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] ): - raise TypeError( - "A streamable data source must be an open file-like object or iterable." - ) + raise TypeError("A streamable data source must be an open file-like object or iterable.") headers = self._set_body(content=data) self._files = None self.headers.update(headers) @@ -367,6 +355,6 @@ def _serialize(self): def _add_backcompat_properties(self, request, memo): """While deepcopying, we also need to add the private backcompat attrs""" - request._multipart_mixed_info = ( # pylint: disable=protected-access - copy.deepcopy(self._multipart_mixed_info, memo) + request._multipart_mixed_info = copy.deepcopy( # pylint: disable=protected-access + self._multipart_mixed_info, memo ) diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py index 4078696c7e1a..035f26ecedbe 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py @@ -96,9 +96,7 @@ def _decode_parts(self, message, http_response_type, requests): Rebuild an HTTP response from pure string. """ - def _deserialize_response( - http_response_as_bytes, http_request, http_response_type - ): + def _deserialize_response(http_response_as_bytes, http_request, http_response_type): local_socket = BytesIOSocket(http_response_as_bytes) response = _HTTPResponse(local_socket, method=http_request.method) response.begin() @@ -120,9 +118,7 @@ def _get_raw_parts(self, http_response_type=None): If parts are application/http use http_response_type or HttpClientTransportResponse as envelope. """ - return _get_raw_parts_helper( - self, http_response_type or RestHttpClientTransportResponse - ) + return _get_raw_parts_helper(self, http_response_type or RestHttpClientTransportResponse) def _stream_download(self, pipeline, **kwargs): """DEPRECATED: Generator for streaming request body data. @@ -180,9 +176,7 @@ def __init__(self, **kwargs) -> None: self._reason: str = kwargs.pop("reason") self._content_type: str = kwargs.pop("content_type") self._headers: MutableMapping[str, str] = kwargs.pop("headers") - self._stream_download_generator: Callable = kwargs.pop( - "stream_download_generator" - ) + self._stream_download_generator: Callable = kwargs.pop("stream_download_generator") self._is_closed = False self._is_stream_consumed = False self._json = None # this is filled in ContentDecodePolicy, when we deserialize @@ -323,17 +317,11 @@ def content(self) -> bytes: return self._content def __repr__(self) -> str: - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "".format( - self.status_code, self.reason, content_type_str - ) + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "".format(self.status_code, self.reason, content_type_str) -class HttpResponseImpl( - _HttpResponseBaseImpl, _HttpResponse, HttpResponseBackcompatMixin -): +class HttpResponseImpl(_HttpResponseBaseImpl, _HttpResponse, HttpResponseBackcompatMixin): """HttpResponseImpl built on top of our HttpResponse protocol class. Since ~azure.core.rest.HttpResponse is an abstract base class, we need to @@ -404,25 +392,19 @@ def iter_raw(self, **kwargs) -> Iterator[bytes]: :rtype: Iterator[str] """ self._stream_download_check() - for part in self._stream_download_generator( - response=self, pipeline=None, decompress=False - ): + for part in self._stream_download_generator(response=self, pipeline=None, decompress=False): yield part self.close() -class _RestHttpClientTransportResponseBackcompatBaseMixin( - _HttpResponseBackcompatMixinBase -): +class _RestHttpClientTransportResponseBackcompatBaseMixin(_HttpResponseBackcompatMixinBase): def body(self): if self._content is None: self._content = self.internal_response.read() return self.content -class _RestHttpClientTransportResponseBase( - _HttpResponseBaseImpl, _RestHttpClientTransportResponseBackcompatBaseMixin -): +class _RestHttpClientTransportResponseBase(_HttpResponseBaseImpl, _RestHttpClientTransportResponseBackcompatBaseMixin): def __init__(self, **kwargs): internal_response = kwargs.pop("internal_response") headers = case_insensitive_dict(internal_response.getheaders()) @@ -437,9 +419,7 @@ def __init__(self, **kwargs): ) -class RestHttpClientTransportResponse( - _RestHttpClientTransportResponseBase, HttpResponseImpl -): +class RestHttpClientTransportResponse(_RestHttpClientTransportResponseBase, HttpResponseImpl): """Create a Rest HTTPResponse from an http.client response.""" def iter_bytes(self, **kwargs): diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py index da5be566f4e3..b3449fb3e4f8 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl_async.py @@ -49,18 +49,12 @@ def parts(self): :raises ValueError: If the content is not multipart/mixed """ if not self.content_type or not self.content_type.startswith("multipart/mixed"): - raise ValueError( - "You can't get parts if the response is not multipart/mixed" - ) + raise ValueError("You can't get parts if the response is not multipart/mixed") - return _PartGenerator( - self, default_http_response_type=RestAsyncHttpClientTransportResponse - ) + return _PartGenerator(self, default_http_response_type=RestAsyncHttpClientTransportResponse) -class AsyncHttpResponseImpl( - _HttpResponseBaseImpl, _AsyncHttpResponse, AsyncHttpResponseBackcompatMixin -): +class AsyncHttpResponseImpl(_HttpResponseBaseImpl, _AsyncHttpResponse, AsyncHttpResponseBackcompatMixin): """AsyncHttpResponseImpl built on top of our HttpResponse protocol class. Since ~azure.core.rest.AsyncHttpResponse is an abstract base class, we need to @@ -103,9 +97,7 @@ async def iter_raw(self, **kwargs: Any) -> AsyncIterator[bytes]: :rtype: AsyncIterator[bytes] """ self._stream_download_check() - async for part in self._stream_download_generator( - response=self, pipeline=None, decompress=False - ): + async for part in self._stream_download_generator(response=self, pipeline=None, decompress=False): yield part await self.close() @@ -119,9 +111,7 @@ async def iter_bytes(self, **kwargs: Any) -> AsyncIterator[bytes]: yield self.content[i : i + self._block_size] else: self._stream_download_check() - async for part in self._stream_download_generator( - response=self, pipeline=None, decompress=True - ): + async for part in self._stream_download_generator(response=self, pipeline=None, decompress=True): yield part await self.close() @@ -139,17 +129,11 @@ async def __aexit__(self, *args) -> None: await self.close() def __repr__(self) -> str: - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "".format( - self.status_code, self.reason, content_type_str - ) + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "".format(self.status_code, self.reason, content_type_str) -class RestAsyncHttpClientTransportResponse( - _RestHttpClientTransportResponseBase, AsyncHttpResponseImpl -): +class RestAsyncHttpClientTransportResponse(_RestHttpClientTransportResponseBase, AsyncHttpResponseImpl): """Create a Rest HTTPResponse from an http.client response.""" async def iter_bytes(self, **kwargs): diff --git a/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py b/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py index b787ec9f3a5c..35e896676e42 100644 --- a/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/rest/_requests_asyncio.py @@ -33,9 +33,7 @@ class RestAsyncioRequestsTransportResponse(AsyncHttpResponseImpl, _RestRequestsT """Asynchronous streaming of data from the response.""" def __init__(self, **kwargs): - super().__init__( - stream_download_generator=AsyncioStreamDownloadGenerator, **kwargs - ) + super().__init__(stream_download_generator=AsyncioStreamDownloadGenerator, **kwargs) async def close(self) -> None: """Close the response. diff --git a/sdk/core/azure-core/azure/core/rest/_requests_basic.py b/sdk/core/azure-core/azure/core/rest/_requests_basic.py index b07a1aa828c9..f00300c0c74d 100644 --- a/sdk/core/azure-core/azure/core/rest/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/rest/_requests_basic.py @@ -75,9 +75,7 @@ def _body(self): return self._content -class _RestRequestsTransportResponseBase( - _HttpResponseBaseImpl, _RestRequestsTransportResponseBaseMixin -): +class _RestRequestsTransportResponseBase(_HttpResponseBaseImpl, _RestRequestsTransportResponseBaseMixin): def __init__(self, **kwargs): internal_response = kwargs.pop("internal_response") content = None @@ -95,10 +93,6 @@ def __init__(self, **kwargs): ) -class RestRequestsTransportResponse( - HttpResponseImpl, _RestRequestsTransportResponseBase -): +class RestRequestsTransportResponse(HttpResponseImpl, _RestRequestsTransportResponseBase): def __init__(self, **kwargs): - super(RestRequestsTransportResponse, self).__init__( - stream_download_generator=StreamDownloadGenerator, **kwargs - ) + super(RestRequestsTransportResponse, self).__init__(stream_download_generator=StreamDownloadGenerator, **kwargs) diff --git a/sdk/core/azure-core/azure/core/rest/_requests_trio.py b/sdk/core/azure-core/azure/core/rest/_requests_trio.py index 3a5e7f720b89..8a8415d4d03b 100644 --- a/sdk/core/azure-core/azure/core/rest/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/rest/_requests_trio.py @@ -33,9 +33,7 @@ class RestTrioRequestsTransportResponse(AsyncHttpResponseImpl, _RestRequestsTran """Asynchronous streaming of data from the response.""" def __init__(self, **kwargs): - super().__init__( - stream_download_generator=TrioStreamDownloadGenerator, **kwargs - ) + super().__init__(stream_download_generator=TrioStreamDownloadGenerator, **kwargs) async def close(self) -> None: if not self.is_closed: diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index f65dc7537eb1..902c01d13821 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -122,9 +122,7 @@ def __init__( if kwargs: raise TypeError( - "You have passed in kwargs '{}' that are not valid kwargs.".format( - "', '".join(list(kwargs.keys())) - ) + "You have passed in kwargs '{}' that are not valid kwargs.".format("', '".join(list(kwargs.keys()))) ) def _set_body( @@ -148,9 +146,7 @@ def _set_body( if files: default_headers, self._files = set_multipart_body(files) if data: - default_headers, self._data = set_urlencoded_body( - data, has_files=bool(files) - ) + default_headers, self._data = set_urlencoded_body(data, has_files=bool(files)) return default_headers @property @@ -371,12 +367,8 @@ def iter_bytes(self, **kwargs: Any) -> Iterator[bytes]: ... def __repr__(self) -> str: - content_type_str = ( - ", Content-Type: {}".format(self.content_type) if self.content_type else "" - ) - return "".format( - self.status_code, self.reason, content_type_str - ) + content_type_str = ", Content-Type: {}".format(self.content_type) if self.content_type else "" + return "".format(self.status_code, self.reason, content_type_str) class AsyncHttpResponse(_HttpResponseBase, AsyncContextManager["AsyncHttpResponse"]): diff --git a/sdk/core/azure-core/azure/core/settings.py b/sdk/core/azure-core/azure/core/settings.py index ee804cfc4cb0..2f078e9f9729 100644 --- a/sdk/core/azure-core/azure/core/settings.py +++ b/sdk/core/azure-core/azure/core/settings.py @@ -111,11 +111,7 @@ def convert_logging(value: Union[str, int]) -> int: val = cast(str, value).upper() level = _levels.get(val) if not level: - raise ValueError( - "Cannot convert {} to log level, valid values are: {}".format( - value, ", ".join(_levels) - ) - ) + raise ValueError("Cannot convert {} to log level, valid values are: {}".format(value, ", ".join(_levels))) return level @@ -142,9 +138,7 @@ def get_opencensus_span_if_opencensus_is_imported() -> Optional[Type[AbstractSpa } -def convert_tracing_impl( - value: Union[str, Type[AbstractSpan]] -) -> Optional[Type[AbstractSpan]]: +def convert_tracing_impl(value: Union[str, Type[AbstractSpan]]) -> Optional[Type[AbstractSpan]]: """Convert a string to AbstractSpan If a AbstractSpan is passed in, it is returned as-is. Otherwise the function @@ -205,9 +199,7 @@ class PrioritizedSetting: """ - def __init__( - self, name, env_var=None, system_hook=None, default=_Unset, convert=None - ): + def __init__(self, name, env_var=None, system_hook=None, default=_Unset, convert=None): self._name = name self._env_var = env_var @@ -384,11 +376,7 @@ def defaults(self): :rtype: namedtuple """ - props = { - k: v.default - for (k, v) in self.__class__.__dict__.items() - if isinstance(v, PrioritizedSetting) - } + props = {k: v.default for (k, v) in self.__class__.__dict__.items() if isinstance(v, PrioritizedSetting)} return self._config(props) @property @@ -412,11 +400,7 @@ def config(self, **kwargs): settings.config(log_level=logging.DEBUG) """ - props = { - k: v() - for (k, v) in self.__class__.__dict__.items() - if isinstance(v, PrioritizedSetting) - } + props = {k: v() for (k, v) in self.__class__.__dict__.items() if isinstance(v, PrioritizedSetting)} props.update(kwargs) return self._config(props) diff --git a/sdk/core/azure-core/azure/core/tracing/_abstract_span.py b/sdk/core/azure-core/azure/core/tracing/_abstract_span.py index 8e4305690440..b29103dbcb23 100644 --- a/sdk/core/azure-core/azure/core/tracing/_abstract_span.py +++ b/sdk/core/azure-core/azure/core/tracing/_abstract_span.py @@ -105,9 +105,7 @@ def add_attribute(self, key: str, value: Union[str, int]) -> None: :type value: str """ - def set_http_attributes( - self, request: "HttpRequest", response: Optional["HttpResponseType"] = None - ) -> None: + def set_http_attributes(self, request: "HttpRequest", response: Optional["HttpResponseType"] = None) -> None: """ Add correct attributes for a http client span. @@ -140,9 +138,7 @@ def link(cls, traceparent: str, attributes: Optional["Attributes"] = None) -> No """ @classmethod - def link_from_headers( - cls, headers: Dict[str, str], attributes: Optional["Attributes"] = None - ) -> None: + def link_from_headers(cls, headers: Dict[str, str], attributes: Optional["Attributes"] = None) -> None: """ Given a dictionary, extracts the context and links the context to the current tracer. @@ -207,9 +203,7 @@ class HttpSpanMixin(_MIXIN_BASE): _HTTP_URL = "http.url" _HTTP_STATUS_CODE = "http.status_code" - def set_http_attributes( - self, request: "HttpRequest", response: Optional["HttpResponseType"] = None - ) -> None: + def set_http_attributes(self, request: "HttpRequest", response: Optional["HttpResponseType"] = None) -> None: """ Add correct attributes for a http client span. @@ -240,8 +234,6 @@ class Link: :type attributes: dict """ - def __init__( - self, headers: Dict[str, str], attributes: Optional["Attributes"] = None - ) -> None: + def __init__(self, headers: Dict[str, str], attributes: Optional["Attributes"] = None) -> None: self.headers = headers self.attributes = attributes diff --git a/sdk/core/azure-core/azure/core/tracing/common.py b/sdk/core/azure-core/azure/core/tracing/common.py index b23186c855e8..f3741cdfdd04 100644 --- a/sdk/core/azure-core/azure/core/tracing/common.py +++ b/sdk/core/azure-core/azure/core/tracing/common.py @@ -51,9 +51,7 @@ def get_function_and_class_name(func: Callable, *args) -> str: return func.__qualname__ except AttributeError: if args: - return "{}.{}".format( - args[0].__class__.__name__, func.__name__ - ) # pylint: disable=protected-access + return "{}.{}".format(args[0].__class__.__name__, func.__name__) # pylint: disable=protected-access return func.__name__ diff --git a/sdk/core/azure-core/azure/core/tracing/decorator.py b/sdk/core/azure-core/azure/core/tracing/decorator.py index 6773d01d82de..02b1bcc0099b 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator.py @@ -50,9 +50,7 @@ def distributed_trace( # pylint:disable=function-redefined pass -def distributed_trace( # pylint:disable=function-redefined - __func: Callable[P, T] = None, **kwargs: Any -): +def distributed_trace(__func: Callable[P, T] = None, **kwargs: Any): # pylint:disable=function-redefined """Decorator to apply to function to get traced automatically. Span will use the func name or "name_of_span". diff --git a/sdk/core/azure-core/azure/core/tracing/decorator_async.py b/sdk/core/azure-core/azure/core/tracing/decorator_async.py index 96fe11ac88bc..ed82bdf09cb6 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator_async.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator_async.py @@ -38,9 +38,7 @@ @overload -def distributed_trace_async( - __func: Callable[P, Awaitable[T]] -) -> Callable[P, Awaitable[T]]: +def distributed_trace_async(__func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: pass diff --git a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py index 62a2f50989a8..50d5001546f1 100644 --- a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py +++ b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py @@ -7,9 +7,7 @@ from typing import Mapping -def parse_connection_string( - conn_str: str, case_sensitive_keys: bool = False -) -> Mapping[str, str]: +def parse_connection_string(conn_str: str, case_sensitive_keys: bool = False) -> Mapping[str, str]: """Parses the connection string into a dict of its component parts, with the option of preserving case of keys, and validates that each key in the connection string has a provided value. If case of keys is not preserved (ie. `case_sensitive_keys=False`), then a dict with LOWERCASE KEYS will be returned. @@ -38,9 +36,7 @@ def parse_connection_string( for key in args_dict.keys(): new_key = key.lower() if new_key in new_args_dict: - raise ValueError( - "Duplicate key in connection string: {}".format(new_key) - ) + raise ValueError("Duplicate key in connection string: {}".format(new_key)) new_args_dict[new_key] = args_dict[key] return new_args_dict diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py index b9f280390e69..e15964dc5e5a 100644 --- a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py @@ -80,9 +80,7 @@ def _format_parameters_helper(http_request, params): query = urlparse(http_request.url).query if query: http_request.url = http_request.url.partition("?")[0] - existing_params = { - p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] - } + existing_params = {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} params.update(existing_params) query_params = [] for k, v in params.items(): @@ -112,9 +110,7 @@ def _pad_attr_name(attr: str, backcompat_attrs: List[str]) -> str: return "_{}".format(attr) if attr in backcompat_attrs else attr -def _prepare_multipart_body_helper( - http_request: "HTTPRequestType", content_index: int = 0 -) -> int: +def _prepare_multipart_body_helper(http_request: "HTTPRequestType", content_index: int = 0) -> int: """Helper for prepare_multipart_body. Will prepare the body of this request according to the multipart information. @@ -162,9 +158,7 @@ def _prepare_multipart_body_helper( eol = b"\r\n" _, _, body = full_message.split(eol, 2) http_request.set_bytes_body(body) - http_request.headers["Content-Type"] = ( - "multipart/mixed; boundary=" + main_message.get_boundary() - ) + http_request.headers["Content-Type"] = "multipart/mixed; boundary=" + main_message.get_boundary() return content_index @@ -231,16 +225,12 @@ def _decode_parts_helper( elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info: # The message batch contains one or more change sets changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore - changeset_responses = ( - response._decode_parts( # pylint: disable=protected-access - raw_response, http_response_type, changeset_requests - ) + changeset_responses = response._decode_parts( # pylint: disable=protected-access + raw_response, http_response_type, changeset_requests ) responses.extend(changeset_responses) else: - raise ValueError( - "Multipart doesn't support part other than application/http for now" - ) + raise ValueError("Multipart doesn't support part other than application/http for now") return responses @@ -254,17 +244,10 @@ def _get_raw_parts_helper(response, http_response_type): """ body_as_bytes = response.body() # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy - http_body = ( - b"Content-Type: " - + response.content_type.encode("ascii") - + b"\r\n\r\n" - + body_as_bytes - ) + http_body = b"Content-Type: " + response.content_type.encode("ascii") + b"\r\n\r\n" + body_as_bytes message: Message = message_parser(http_body) requests = response.request.multipart_mixed_info[0] - return response._decode_parts( # pylint: disable=protected-access - message, http_response_type, requests - ) + return response._decode_parts(message, http_response_type, requests) # pylint: disable=protected-access def _parts_helper( @@ -275,9 +258,7 @@ def _parts_helper( :rtype: iterator[HttpResponse] :raises ValueError: If the content is not multipart/mixed """ - if not response.content_type or not response.content_type.startswith( - "multipart/mixed" - ): + if not response.content_type or not response.content_type.startswith("multipart/mixed"): raise ValueError("You can't get parts if the response is not multipart/mixed") responses = response._get_raw_parts() # pylint: disable=protected-access @@ -291,9 +272,7 @@ def parse_responses(response): http_request = response.request context = PipelineContext(None) pipeline_request = PipelineRequest(http_request, context) - pipeline_response = PipelineResponse( - http_request, response, context=context - ) + pipeline_response = PipelineResponse(http_request, response, context=context) for policy in policies: _await_result(policy.on_response, pipeline_request, pipeline_response) @@ -307,9 +286,7 @@ def parse_responses(response): return responses -def _format_data_helper( - data: Union[str, IO] -) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]: +def _format_data_helper(data: Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]: """Helper for _format_data. Format field data according to whether it is a stream or @@ -343,9 +320,7 @@ def _aiohttp_body_helper( :rtype: bytes """ if response._content is None: - raise ValueError( - "Body is not available. Call async method load_body, or do your call with stream=False." - ) + raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.") if not response._decompress: return response._content if response._decompressed_content: diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py index b30b7c02349b..3d03f4b93771 100644 --- a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared_async.py @@ -30,22 +30,16 @@ async def _parse_response(self): http_response_type=self._default_http_response_type ) if self._response.request.multipart_mixed_info: - policies: List[ - "SansIOHTTPPolicy" - ] = self._response.request.multipart_mixed_info[1] + policies: List["SansIOHTTPPolicy"] = self._response.request.multipart_mixed_info[1] async def parse_responses(response): http_request = response.request context = PipelineContext(None) pipeline_request = PipelineRequest(http_request, context) - pipeline_response = PipelineResponse( - http_request, response, context=context - ) + pipeline_response = PipelineResponse(http_request, response, context=context) for policy in policies: - await _await_result( - policy.on_response, pipeline_request, pipeline_response - ) + await _await_result(policy.on_response, pipeline_request, pipeline_response) # Not happy to make this code asyncio specific, but that's multipart only for now # If we need trio and multipart, let's reinvesitgate that later diff --git a/sdk/core/azure-core/azure/core/utils/_utils.py b/sdk/core/azure-core/azure/core/utils/_utils.py index fc4a97433c06..134d79be781c 100644 --- a/sdk/core/azure-core/azure/core/utils/_utils.py +++ b/sdk/core/azure-core/azure/core/utils/_utils.py @@ -105,9 +105,7 @@ class CaseInsensitiveDict(MutableMapping[str, Any]): """ def __init__( - self, - data: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None, - **kwargs: Any + self, data: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None, **kwargs: Any ) -> None: self._store: Dict[str, Any] = {} if data is None: @@ -137,9 +135,7 @@ def __len__(self) -> int: return len(self._store) def lowerkey_items(self): - return ( - (lower_case_key, pair[1]) for lower_case_key, pair in self._store.items() - ) + return ((lower_case_key, pair[1]) for lower_case_key, pair in self._store.items()) def __eq__(self, other: Any) -> bool: if isinstance(other, Mapping): diff --git a/sdk/core/azure-core/samples/test_example_async.py b/sdk/core/azure-core/samples/test_example_async.py index 7f439e273ac9..2d4db41febc2 100644 --- a/sdk/core/azure-core/samples/test_example_async.py +++ b/sdk/core/azure-core/samples/test_example_async.py @@ -35,19 +35,16 @@ @pytest.mark.asyncio async def test_example_trio(): - async def req(): request = HttpRequest("GET", "https://bing.com/") - policies = [ - UserAgentPolicy("myuseragent"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), AsyncRedirectPolicy()] # [START trio] from azure.core.pipeline.transport import TrioRequestsTransport async with AsyncPipeline(TrioRequestsTransport(), policies=policies) as pipeline: return await pipeline.run(request) # [END trio] + response = trio.run(req) assert isinstance(response.http_response.status_code, int) @@ -56,10 +53,7 @@ async def req(): async def test_example_asyncio(): request = HttpRequest("GET", "https://bing.com") - policies = [ - UserAgentPolicy("myuseragent"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), AsyncRedirectPolicy()] # [START asyncio] from azure.core.pipeline.transport import AsyncioRequestsTransport @@ -74,10 +68,7 @@ async def test_example_asyncio(): async def test_example_aiohttp(): request = HttpRequest("GET", "https://bing.com") - policies = [ - UserAgentPolicy("myuseragent"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), AsyncRedirectPolicy()] # [START aiohttp] from azure.core.pipeline.transport import AioHttpTransport @@ -97,10 +88,7 @@ async def test_example_async_pipeline(): # example: create request and policies request = HttpRequest("GET", "https://bing.com") - policies = [ - UserAgentPolicy("myuseragent"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), AsyncRedirectPolicy()] # run the pipeline async with AsyncPipeline(transport=AioHttpTransport(), policies=policies) as pipeline: @@ -221,7 +209,7 @@ async def test_example_async_retry_policy(): retry_status=5, retry_backoff_factor=0.5, retry_backoff_max=60, - retry_on_methods=['GET'] + retry_on_methods=["GET"], ) # [END async_retry_policy] diff --git a/sdk/core/azure-core/samples/test_example_policies.py b/sdk/core/azure-core/samples/test_example_policies.py index cbf1a9933c2b..50dd93ff0908 100644 --- a/sdk/core/azure-core/samples/test_example_policies.py +++ b/sdk/core/azure-core/samples/test_example_policies.py @@ -24,6 +24,7 @@ # # -------------------------------------------------------------------------- + def test_example_raw_response_hook(): def callback(response): response.http_response.status_code = 200 @@ -35,9 +36,7 @@ def callback(response): from azure.core.pipeline.policies import CustomHookPolicy request = HttpRequest("GET", "https://bing.com") - policies = [ - CustomHookPolicy(raw_response_hook=callback) - ] + policies = [CustomHookPolicy(raw_response_hook=callback)] with Pipeline(transport=RequestsTransport(), policies=policies) as pipeline: response = pipeline.run(request) diff --git a/sdk/core/azure-core/samples/test_example_sansio.py b/sdk/core/azure-core/samples/test_example_sansio.py index 9006aaffa8b2..d63852c778cb 100644 --- a/sdk/core/azure-core/samples/test_example_sansio.py +++ b/sdk/core/azure-core/samples/test_example_sansio.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import sys from azure.core.pipeline.transport import HttpRequest @@ -35,16 +35,13 @@ def test_example_headers_policy(): url = "https://bing.com" - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy()] # [START headers_policy] from azure.core.pipeline.policies import HeadersPolicy headers_policy = HeadersPolicy() - headers_policy.add_header('CustomValue', 'Foo') + headers_policy.add_header("CustomValue", "Foo") # Or headers can be added per operation. These headers will supplement existing headers # or those defined in the config headers policy. They will also overwrite existing @@ -52,24 +49,22 @@ def test_example_headers_policy(): policies.append(headers_policy) client = PipelineClient(base_url=url, policies=policies) request = client.get(url) - pipeline_response = client._pipeline.run(request, headers={'CustomValue': 'Bar'}) + pipeline_response = client._pipeline.run(request, headers={"CustomValue": "Bar"}) # [END headers_policy] response = pipeline_response.http_response assert isinstance(response.status_code, int) + def test_example_request_id_policy(): url = "https://bing.com" - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy()] # [START request_id_policy] from azure.core.pipeline.policies import HeadersPolicy request_id_policy = RequestIdPolicy() - request_id_policy.set_request_id('azconfig-test') + request_id_policy.set_request_id("azconfig-test") # Or headers can be added per operation. These headers will supplement existing headers # or those defined in the config headers policy. They will also overwrite existing @@ -98,7 +93,7 @@ def test_example_user_agent_policy(): # You can also pass in a custom value per operation to append to the end of the user-agent. # This can be used together with the policy configuration to append multiple values. - policies=[ + policies = [ redirect_policy, user_agent_policy, ] @@ -114,10 +109,7 @@ def test_example_user_agent_policy(): def example_network_trace_logging(): filename = "log.txt" url = "https://bing.com" - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy()] # [START network_trace_logging_policy] from azure.core.pipeline.policies import NetworkTraceLoggingPolicy @@ -151,6 +143,7 @@ def example_network_trace_logging(): response = pipeline_response.http_response assert isinstance(response.status_code, int) + def example_proxy_policy(): # [START proxy_policy] @@ -159,10 +152,10 @@ def example_proxy_policy(): proxy_policy = ProxyPolicy() # Example - proxy_policy.proxies = {'http': 'http://10.10.1.10:3148'} + proxy_policy.proxies = {"http": "http://10.10.1.10:3148"} # Use basic auth - proxy_policy.proxies = {'https': 'http://user:password@10.10.1.10:1180/'} + proxy_policy.proxies = {"https": "http://user:password@10.10.1.10:1180/"} # You can also configure proxies by setting the environment variables # HTTP_PROXY and HTTPS_PROXY. diff --git a/sdk/core/azure-core/samples/test_example_sync.py b/sdk/core/azure-core/samples/test_example_sync.py index 071d1a643ef0..073ed9f4b84c 100644 --- a/sdk/core/azure-core/samples/test_example_sync.py +++ b/sdk/core/azure-core/samples/test_example_sync.py @@ -27,20 +27,12 @@ from azure.core.pipeline import Pipeline from azure.core import PipelineClient from azure.core.pipeline.transport import HttpRequest -from azure.core.pipeline.policies import ( - UserAgentPolicy, - RedirectPolicy, - RetryPolicy -) +from azure.core.pipeline.policies import UserAgentPolicy, RedirectPolicy, RetryPolicy def test_example_requests(): request = HttpRequest("GET", "https://bing.com") - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy(), - RetryPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy(), RetryPolicy()] # [START requests] from azure.core.pipeline.transport import RequestsTransport @@ -59,10 +51,7 @@ def test_example_pipeline(): # example: create request and policies request = HttpRequest("GET", "https://bing.com") - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy()] # run the pipeline with Pipeline(transport=RequestsTransport(), policies=policies) as pipeline: @@ -79,10 +68,7 @@ def test_example_pipeline_client(): from azure.core.pipeline.policies import RedirectPolicy, UserAgentPolicy # example configuration with some policies - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy()] client = PipelineClient(base_url=url, policies=policies) request = client.get("https://bing.com") @@ -93,6 +79,7 @@ def test_example_pipeline_client(): response = pipeline_response.http_response assert isinstance(response.status_code, int) + def test_example_redirect_policy(): url = "https://bing.com" @@ -139,10 +126,7 @@ def test_example_retry_policy(): url = "https://bing.com" - policies = [ - UserAgentPolicy("myuseragent"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myuseragent"), RedirectPolicy()] # [START retry_policy] from azure.core.pipeline.policies import RetryPolicy @@ -193,13 +177,14 @@ def test_example_retry_policy(): retry_status=5, retry_backoff_factor=0.5, retry_backoff_max=120, - retry_on_methods=['GET'] + retry_on_methods=["GET"], ) # [END retry_policy] response = pipeline_response.http_response assert isinstance(response.status_code, int) + def test_example_no_retries(): url = "https://bing.com" diff --git a/sdk/core/azure-core/setup.py b/sdk/core/azure-core/setup.py index f882a330ce7e..8cd4d45c1be6 100644 --- a/sdk/core/azure-core/setup.py +++ b/sdk/core/azure-core/setup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import re import os.path @@ -16,59 +16,60 @@ PACKAGE_PPRINT_NAME = "Core" # a-b-c => a/b/c -package_folder_path = PACKAGE_NAME.replace('-', '/') +package_folder_path = PACKAGE_NAME.replace("-", "/") # a-b-c => a.b.c -namespace_name = PACKAGE_NAME.replace('-', '.') +namespace_name = PACKAGE_NAME.replace("-", ".") # Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, '_version.py'), 'r') as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', # type: ignore - fd.read(), re.MULTILINE).group(1) +with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) # type: ignore if not version: - raise RuntimeError('Cannot find version information') + raise RuntimeError("Cannot find version information") -with open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() -with open('CHANGELOG.md', encoding='utf-8') as f: +with open("CHANGELOG.md", encoding="utf-8") as f: changelog = f.read() setup( name=PACKAGE_NAME, version=version, include_package_data=True, - description='Microsoft Azure {} Library for Python'.format(PACKAGE_PPRINT_NAME), - long_description=readme + '\n\n' + changelog, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='azpysdkhelp@microsoft.com', - url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core', + description="Microsoft Azure {} Library for Python".format(PACKAGE_PPRINT_NAME), + long_description=readme + "\n\n" + changelog, + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="azpysdkhelp@microsoft.com", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-core", classifiers=[ "Development Status :: 5 - Production/Stable", - 'Programming Language :: Python', + "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: MIT License', + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: MIT License", ], zip_safe=False, - packages=find_packages(exclude=[ - 'tests', - # Exclude packages that will be covered by PEP420 or nspkg - 'azure', - ]), + packages=find_packages( + exclude=[ + "tests", + # Exclude packages that will be covered by PEP420 or nspkg + "azure", + ] + ), package_data={ - 'pytyped': ['py.typed'], + "pytyped": ["py.typed"], }, python_requires=">=3.7", install_requires=[ - 'requests>=2.18.4', - 'six>=1.11.0', + "requests>=2.18.4", + "six>=1.11.0", "typing-extensions>=4.0.1", ], extras_require={ diff --git a/sdk/core/azure-core/tests/async_tests/conftest.py b/sdk/core/azure-core/tests/async_tests/conftest.py index 99b7a90b0918..329c8970efd0 100644 --- a/sdk/core/azure-core/tests/async_tests/conftest.py +++ b/sdk/core/azure-core/tests/async_tests/conftest.py @@ -33,6 +33,7 @@ import urllib from rest_client_async import AsyncTestRestClient + def is_port_available(port_num): req = urllib.request.Request("http://localhost:{}/health".format(port_num)) try: @@ -40,6 +41,7 @@ def is_port_available(port_num): except Exception as e: return True + def get_port(): count = 3 for _ in range(count): @@ -48,19 +50,21 @@ def get_port(): return port_num raise TypeError("Tried {} times, can't find an open port".format(count)) + @pytest.fixture def port(): return os.environ["FLASK_PORT"] + def start_testserver(): port = get_port() os.environ["FLASK_APP"] = "coretestserver" os.environ["FLASK_PORT"] = str(port) cmd = "flask run -p {}".format(port) - if os.name == 'nt': #On windows, subprocess creation works without being in the shell + if os.name == "nt": # On windows, subprocess creation works without being in the shell child_process = subprocess.Popen(cmd, env=dict(os.environ)) else: - #On linux, have to set shell=True + # On linux, have to set shell=True child_process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setsid, env=dict(os.environ)) count = 5 for _ in range(count): @@ -69,12 +73,14 @@ def start_testserver(): time.sleep(1) raise ValueError("Didn't start!") + def terminate_testserver(process): - if os.name == 'nt': + if os.name == "nt": process.kill() else: os.killpg(os.getpgid(process.pid), signal.SIGTERM) # Send the signal to all the process groups + @pytest.fixture(autouse=True, scope="package") def testserver(): """Start the Autorest testserver.""" @@ -82,6 +88,7 @@ def testserver(): yield terminate_testserver(server) + @pytest.fixture def client(port): return AsyncTestRestClient(port) diff --git a/sdk/core/azure-core/tests/async_tests/rest_client_async.py b/sdk/core/azure-core/tests/async_tests/rest_client_async.py index 2ad83c2dea4c..ce0b08be957c 100644 --- a/sdk/core/azure-core/tests/async_tests/rest_client_async.py +++ b/sdk/core/azure-core/tests/async_tests/rest_client_async.py @@ -8,10 +8,9 @@ from azure.core.pipeline import policies from azure.core.configuration import Configuration + class TestRestClientConfiguration(Configuration): - def __init__( - self, **kwargs - ): + def __init__(self, **kwargs): # type: (...) -> None super(TestRestClientConfiguration, self).__init__(**kwargs) @@ -29,16 +28,12 @@ def _configure(self, **kwargs) -> None: self.redirect_policy = kwargs.get("redirect_policy") or policies.AsyncRedirectPolicy(**kwargs) self.authentication_policy = kwargs.get("authentication_policy") -class AsyncTestRestClient(object): +class AsyncTestRestClient(object): def __init__(self, port, **kwargs): self._config = TestRestClientConfiguration(**kwargs) - self._client = AsyncPipelineClient( - base_url="http://localhost:{}".format(port), - config=self._config, - **kwargs - ) + self._client = AsyncPipelineClient(base_url="http://localhost:{}".format(port), config=self._config, **kwargs) def send_request(self, request, **kwargs): """Runs the network request through the client's chained policies. diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 2120b8efb018..3c676a2e97aa 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -17,6 +17,7 @@ pytestmark = pytest.mark.asyncio from utils import HTTP_REQUESTS + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_bearer_policy_adds_header(http_request): """The bearer token policy should add a header containing a token from its credential""" @@ -209,6 +210,7 @@ async def fake_send(*args, **kwargs): fake_send.calls = 1 return Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) raise TestException() + fake_send.calls = 0 policy = TestPolicy(credential, "scope") diff --git a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py index f5cdf7425cae..e737b8804f11 100644 --- a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import base64 import json import pickle @@ -31,6 +31,7 @@ from azure.core.pipeline._tools import is_rest import types import unittest + try: from unittest import mock except ImportError: @@ -52,6 +53,7 @@ from utils import ASYNCIO_REQUESTS_TRANSPORT_RESPONSES, request_and_responses_product, create_transport_response from rest_client_async import AsyncTestRestClient + class SimpleResource: """An implementation of Python 3 SimpleNamespace. Used to deserialize resource objects from response bodies where @@ -69,25 +71,31 @@ def __repr__(self): def __eq__(self, other): return self.__dict__ == other.__dict__ + class BadEndpointError(Exception): pass -TEST_NAME = 'foo' -RESPONSE_BODY = {'properties':{'provisioningState': 'InProgress'}} -ASYNC_BODY = json.dumps({ 'status': 'Succeeded' }) -ASYNC_URL = 'http://dummyurlFromAzureAsyncOPHeader_Return200' -LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) -LOCATION_URL = 'http://dummyurlurlFromLocationHeader_Return200' -RESOURCE_BODY = json.dumps({ 'name': TEST_NAME }) -RESOURCE_URL = 'http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1' -ERROR = 'http://dummyurl_ReturnError' + +TEST_NAME = "foo" +RESPONSE_BODY = {"properties": {"provisioningState": "InProgress"}} +ASYNC_BODY = json.dumps({"status": "Succeeded"}) +ASYNC_URL = "http://dummyurlFromAzureAsyncOPHeader_Return200" +LOCATION_BODY = json.dumps({"name": TEST_NAME}) +LOCATION_URL = "http://dummyurlurlFromLocationHeader_Return200" +RESOURCE_BODY = json.dumps({"name": TEST_NAME}) +RESOURCE_URL = "http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1" +ERROR = "http://dummyurl_ReturnError" POLLING_STATUS = 200 CLIENT = AsyncPipelineClient("http://example.org") CLIENT.http_request_type = None CLIENT.http_response_type = None + + async def mock_run(client_self, request, **kwargs): return TestBasePolling.mock_update(client_self.http_request_type, client_self.http_response_type, request.url) + + CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -103,21 +111,23 @@ def async_pipeline_client_builder(): send will receive "request" and kwargs as any transport layer """ + def create_client(send_cb): class TestHttpTransport(AsyncHttpTransport): - async def open(self): pass - async def close(self): pass - async def __aexit__(self, *args, **kwargs): pass + async def open(self): + pass + + async def close(self): + pass + + async def __aexit__(self, *args, **kwargs): + pass async def send(self, request, **kwargs): return await send_cb(request, **kwargs) - return AsyncPipelineClient( - 'http://example.org/', - pipeline=AsyncPipeline( - transport=TestHttpTransport() - ) - ) + return AsyncPipelineClient("http://example.org/", pipeline=AsyncPipeline(transport=TestHttpTransport())) + return create_client @@ -125,6 +135,7 @@ async def send(self, request, **kwargs): def deserialization_cb(): def cb(pipeline_response): return json.loads(pipeline_response.http_response.text()) + return cb @@ -143,7 +154,7 @@ def polling_response(): None, response, ), - PipelineContext(None) + PipelineContext(None), ) polling._initial_response = polling._pipeline_response return polling, headers @@ -165,145 +176,116 @@ def test_base_polling_continuation_token(client, polling_response): @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_post(async_pipeline_client_builder, deserialization_cb, http_request, http_response): - # Test POST LRO with both Location and Operation-Location - - # The initial response contains both Location and Operation-Location, a 202 and no Body - initial_response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', - 202, - { - 'location': 'http://example.org/location', - 'operation-location': 'http://example.org/async_monitor', - }, - '' - ) + # Test POST LRO with both Location and Operation-Location + + # The initial response contains both Location and Operation-Location, a 202 and no Body + initial_response = TestBasePolling.mock_send( + http_request, + http_response, + "POST", + 202, + { + "location": "http://example.org/location", + "operation-location": "http://example.org/async_monitor", + }, + "", + ) + + async def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/location": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"location_result": True} + ).http_response + elif request.url == "http://example.org/async_monitor": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"status": "Succeeded"} + ).http_response + else: + pytest.fail("No other query allowed") + + client = async_pipeline_client_builder(send) + + # LRO options with Location final state + poll = async_poller(client, initial_response, deserialization_cb, AsyncLROBasePolling(0)) + result = await poll + assert result["location_result"] == True + + # Location has no body + + async def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/location": + return TestBasePolling.mock_send(http_request, http_response, "GET", 200, body=None).http_response + elif request.url == "http://example.org/async_monitor": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"status": "Succeeded"} + ).http_response + else: + pytest.fail("No other query allowed") - async def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'location_result': True} - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = async_pipeline_client_builder(send) - - # LRO options with Location final state - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncLROBasePolling(0)) - result = await poll - assert result['location_result'] == True - - # Location has no body - - async def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body=None - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = async_pipeline_client_builder(send) - - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncLROBasePolling(0)) - result = await poll - assert result is None + client = async_pipeline_client_builder(send) + + poll = async_poller(client, initial_response, deserialization_cb, AsyncLROBasePolling(0)) + result = await poll + assert result is None @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb, http_request, http_response): - # ResourceLocation - - # The initial response contains both Location and Operation-Location, a 202 and no Body - initial_response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', - 202, - { - 'operation-location': 'http://example.org/async_monitor', - }, - '' - ) + # ResourceLocation + + # The initial response contains both Location and Operation-Location, a 202 and no Body + initial_response = TestBasePolling.mock_send( + http_request, + http_response, + "POST", + 202, + { + "operation-location": "http://example.org/async_monitor", + }, + "", + ) + + async def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/resource_location": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"location_result": True} + ).http_response + elif request.url == "http://example.org/async_monitor": + return TestBasePolling.mock_send( + http_request, + http_response, + "GET", + 200, + body={"status": "Succeeded", "resourceLocation": "http://example.org/resource_location"}, + ).http_response + else: + pytest.fail("No other query allowed") + + client = async_pipeline_client_builder(send) + + poll = async_poller(client, initial_response, deserialization_cb, AsyncLROBasePolling(0)) + result = await poll + assert result["location_result"] == True - async def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/resource_location': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'location_result': True} - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = async_pipeline_client_builder(send) - - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncLROBasePolling(0)) - result = await poll - assert result['location_result'] == True class TestBasePolling(object): - convert = re.compile('([a-z0-9])([A-Z])') + convert = re.compile("([a-z0-9])([A-Z])") @staticmethod def mock_send(http_request, http_response, method, status, headers=None, body=RESPONSE_BODY): @@ -311,13 +293,11 @@ def mock_send(http_request, http_response, method, status, headers=None, body=RE headers = {} response = Response() response._content_consumed = True - response._content = json.dumps(body).encode('ascii') if body is not None else None + response._content = json.dumps(body).encode("ascii") if body is not None else None response.request = Request() response.request.method = method response.request.url = RESOURCE_URL - response.request.headers = { - 'x-ms-client-request-id': '67f4dd4e-6262-45e1-8bed-5c45cf23b6d9' - } + response.request.headers = {"x-ms-client-request-id": "67f4dd4e-6262-45e1-8bed-5c45cf23b6d9"} response.status_code = status response.headers = headers response.headers.update({"content-type": "application/json; charset=utf8"}) @@ -338,23 +318,19 @@ def mock_send(http_request, http_response, method, status, headers=None, body=RE response.request.headers, body, None, # form_content - None # stream_content + None, # stream_content ) response = create_transport_response(http_response, request, response) if is_rest(http_response): response.body() - return PipelineResponse( - request, - response, - None # context - ) + return PipelineResponse(request, response, None) # context @staticmethod def mock_update(http_request, http_response, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) - response.request.method = 'GET' + response.request.method = "GET" response.headers = headers or {} response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" @@ -362,13 +338,13 @@ def mock_update(http_request, http_response, url, headers=None): if url == ASYNC_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = ASYNC_BODY.encode('ascii') + response._content = ASYNC_BODY.encode("ascii") response.randomFieldFromPollAsyncOpHeader = None elif url == LOCATION_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = LOCATION_BODY.encode('ascii') + response._content = LOCATION_BODY.encode("ascii") response.randomFieldFromPollLocationHeader = None elif url == ERROR: @@ -377,10 +353,10 @@ def mock_update(http_request, http_response, url, headers=None): elif url == RESOURCE_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = RESOURCE_BODY.encode('ascii') + response._content = RESOURCE_BODY.encode("ascii") else: - raise Exception('URL does not match') + raise Exception("URL does not match") request = http_request( response.request.method, @@ -391,11 +367,7 @@ def mock_update(http_request, http_response, url, headers=None): if is_rest(http_response): response.body() - return PipelineResponse( - request, - response, - None # context - ) + return PipelineResponse(request, response, None) # context @staticmethod def mock_outputs(pipeline_response): @@ -405,15 +377,13 @@ def mock_outputs(pipeline_response): except ValueError: raise DecodeError("Impossible to deserialize") - body = {TestBasePolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in body.items()} - properties = body.setdefault('properties', {}) - if 'name' in body: - properties['name'] = body['name'] + body = {TestBasePolling.convert.sub(r"\1_\2", k).lower(): v for k, v in body.items()} + properties = body.setdefault("properties", {}) + if "name" in body: + properties["name"] = body["name"] if properties: - properties = {TestBasePolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in properties.items()} - del body['properties'] + properties = {TestBasePolling.convert.sub(r"\1_\2", k).lower(): v for k, v in properties.items()} + del body["properties"] body.update(properties) resource = SimpleResource(**body) else: @@ -423,110 +393,74 @@ def mock_outputs(pipeline_response): @staticmethod def mock_deserialization_no_body(pipeline_response): - """Use this mock when you don't expect a return (last body irrelevant) - """ + """Use this mock when you don't expect a return (last body irrelevant)""" return None + @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_long_running_put(http_request, http_response): - #TODO: Test custom header field + # TODO: Test custom header field CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response # Test throw on non LRO related status code - response = TestBasePolling.mock_send( - http_request, http_response, 'PUT', 1000, {} - ) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 1000, {}) with pytest.raises(HttpResponseError): - await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) # Test with no polling necessary - response_body = { - 'properties':{'provisioningState': 'Succeeded'}, - 'name': TEST_NAME - } - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {}, response_body - ) + response_body = {"properties": {"provisioningState": "Succeeded"}, "name": TEST_NAME} + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {}, response_body) + def no_update_allowed(url, headers=None): raise ValueError("Should not try to update") + polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method - ) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'operation-location': ASYNC_URL}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"operation-location": ASYNC_URL}) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'location': LOCATION_URL}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"location": LOCATION_URL}) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'location': LOCATION_URL}, response_body) + http_request, http_response, "PUT", 201, {"location": LOCATION_URL}, response_body + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test fail to poll from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'operation-location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"operation-location": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) # Test fail to poll from location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) + @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_long_running_patch(http_request, http_response): CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response @@ -534,13 +468,13 @@ async def test_long_running_patch(http_request, http_response): response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", + 202, + {"location": LOCATION_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None @@ -548,27 +482,27 @@ async def test_long_running_patch(http_request, http_response): response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 202, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", + 202, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from location header response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 200, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", + 200, + {"location": LOCATION_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None @@ -576,60 +510,48 @@ async def test_long_running_patch(http_request, http_response): response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 200, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", + 200, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test fail to poll from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PATCH', 202, - {'operation-location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PATCH", 202, {"operation-location": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) # Test fail to poll from location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PATCH', 202, - {'location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PATCH", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) + @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_long_running_delete(http_request, http_response): # Test polling from operation-location header CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response response = TestBasePolling.mock_send( - http_request, - http_response, - 'DELETE', 202, - {'operation-location': ASYNC_URL}, - body="" + http_request, http_response, "DELETE", 202, {"operation-location": ASYNC_URL}, body="" ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_deserialization_no_body, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, polling_method) assert poll is None assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_long_running_post(http_request, http_response): CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response @@ -637,118 +559,91 @@ async def test_long_running_post(http_request, http_response): response = TestBasePolling.mock_send( http_request, http_response, - 'POST', 201, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", + 201, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_deserialization_no_body, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, polling_method) assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, http_response, - 'POST', 202, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", + 202, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_deserialization_no_body, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, polling_method) assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None # Test polling from location header response = TestBasePolling.mock_send( http_request, http_response, - 'POST', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", + 202, + {"location": LOCATION_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) polling_method = AsyncLROBasePolling(0) - poll = await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestBasePolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test fail to poll from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'operation-location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"operation-location": ERROR}) with pytest.raises(BadEndpointError): - await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) # Test fail to poll from location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - await async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + await async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) + @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_long_running_negative(http_request, http_response): global LOCATION_BODY global POLLING_STATUS CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response # Test LRO PUT throws for invalid json - LOCATION_BODY = '{' - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': LOCATION_URL}) - poll = async_poller( - CLIENT, - response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0) - ) + LOCATION_BODY = "{" + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": LOCATION_URL}) + poll = async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) with pytest.raises(DecodeError): await poll - LOCATION_BODY = '{\'"}' - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': LOCATION_URL}) - poll = async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) + LOCATION_BODY = "{'\"}" + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": LOCATION_URL}) + poll = async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) with pytest.raises(DecodeError): await poll - LOCATION_BODY = '{' + LOCATION_BODY = "{" POLLING_STATUS = 203 - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': LOCATION_URL}) - poll = async_poller(CLIENT, response, - TestBasePolling.mock_outputs, - AsyncLROBasePolling(0)) - with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": LOCATION_URL}) + poll = async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0)) + with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization await poll - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode('ascii') + assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") - LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) + LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 + @pytest.mark.asyncio -@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response", request_and_responses_product(ASYNCIO_REQUESTS_TRANSPORT_RESPONSES) +) async def test_post_final_state_via(async_pipeline_client_builder, deserialization_cb, http_request, http_response): # Test POST LRO with both Location and Operation-Location CLIENT.http_request_type = http_request @@ -757,33 +652,25 @@ async def test_post_final_state_via(async_pipeline_client_builder, deserializati initial_response = TestBasePolling.mock_send( http_request, http_response, - 'POST', + "POST", 202, { - 'location': 'http://example.org/location', - 'operation-location': 'http://example.org/async_monitor', + "location": "http://example.org/location", + "operation-location": "http://example.org/async_monitor", }, - '' + "", ) async def send(request, **kwargs): - assert request.method == 'GET' + assert request.method == "GET" - if request.url == 'http://example.org/location': + if request.url == "http://example.org/location": return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'location_result': True} + http_request, http_response, "GET", 200, body={"location_result": True} ).http_response - elif request.url == 'http://example.org/async_monitor': + elif request.url == "http://example.org/async_monitor": return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} + http_request, http_response, "GET", 200, body={"status": "Succeeded"} ).http_response else: pytest.fail("No other query allowed") @@ -795,48 +682,36 @@ async def send(request, **kwargs): client, initial_response, deserialization_cb, - AsyncLROBasePolling(0, lro_options={"final-state-via": "location"})) + AsyncLROBasePolling(0, lro_options={"final-state-via": "location"}), + ) result = await poll - assert result['location_result'] == True + assert result["location_result"] == True # Test 2, LRO options with Operation-Location final state poll = async_poller( client, initial_response, deserialization_cb, - AsyncLROBasePolling(0, lro_options={"final-state-via": "operation-location"})) + AsyncLROBasePolling(0, lro_options={"final-state-via": "operation-location"}), + ) result = await poll - assert result['status'] == 'Succeeded' + assert result["status"] == "Succeeded" # Test 3, "do the right thing" and use Location by default - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncLROBasePolling(0)) + poll = async_poller(client, initial_response, deserialization_cb, AsyncLROBasePolling(0)) result = await poll - assert result['location_result'] == True + assert result["location_result"] == True # Test 4, location has no body async def send(request, **kwargs): - assert request.method == 'GET' + assert request.method == "GET" - if request.url == 'http://example.org/location': + if request.url == "http://example.org/location": + return TestBasePolling.mock_send(http_request, http_response, "GET", 200, body=None).http_response + elif request.url == "http://example.org/async_monitor": return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body=None - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} + http_request, http_response, "GET", 200, body={"status": "Succeeded"} ).http_response else: pytest.fail("No other query allowed") @@ -847,10 +722,12 @@ async def send(request, **kwargs): client, initial_response, deserialization_cb, - AsyncLROBasePolling(0, lro_options={"final-state-via": "location"})) + AsyncLROBasePolling(0, lro_options={"final-state-via": "location"}), + ) result = await poll assert result is None + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_final_get_via_location(port, http_request, deserialization_cb): diff --git a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py index f68634652040..018aff29a251 100644 --- a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py @@ -3,7 +3,11 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from azure.core.pipeline.transport import AsyncHttpResponse as PipelineTransportAsyncHttpResponse, AsyncHttpTransport, AioHttpTransport +from azure.core.pipeline.transport import ( + AsyncHttpResponse as PipelineTransportAsyncHttpResponse, + AsyncHttpTransport, + AioHttpTransport, +) from azure.core.rest._http_response_impl_async import AsyncHttpResponseImpl as RestAsyncHttpResponse from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import AsyncPipeline @@ -17,12 +21,22 @@ # MagicMock support async cxt manager only after 3.8 # https://github.com/python/cpython/pull/9296 + class MockAsyncHttpTransport(AsyncHttpTransport): - async def __aenter__(self): return self - async def __aexit__(self, *args): pass - async def open(self): pass - async def close(self): pass - async def send(self, request, **kwargs): pass + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def open(self): + pass + + async def close(self): + pass + + async def send(self, request, **kwargs): + pass class PipelineTransportMockResponse(PipelineTransportAsyncHttpResponse): @@ -34,6 +48,7 @@ def __init__(self, request, body, content_type): def body(self): return self._body + class RestMockResponse(RestAsyncHttpResponse): def __init__(self, request, body, content_type): super(RestMockResponse, self).__init__( @@ -57,8 +72,10 @@ def body(self): def content(self): return self._content + MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_basic_options_aiohttp(port, http_request): @@ -79,7 +96,7 @@ async def test_multipart_send(http_request): class RequestPolicy(object): async def on_request(self, request): # type: (PipelineRequest) -> None - request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' + request.http_request.headers["x-ms-date"] = "Thu, 14 Jun 2018 16:46:54 GMT" req0 = http_request("DELETE", "/container0/blob0") req1 = http_request("DELETE", "/container1/blob1") @@ -89,32 +106,32 @@ async def on_request(self, request): req0, req1, policies=[RequestPolicy()], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" # Fix it so test are deterministic + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", # Fix it so test are deterministic ) async with AsyncPipeline(transport) as pipeline: await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -128,7 +145,7 @@ async def test_multipart_send_with_context(http_request): class RequestPolicy(object): async def on_request(self, request): # type: (PipelineRequest) -> None - request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' + request.http_request.headers["x-ms-date"] = "Thu, 14 Jun 2018 16:46:54 GMT" req0 = http_request("DELETE", "/container0/blob0") req1 = http_request("DELETE", "/container1/blob1") @@ -138,35 +155,35 @@ async def on_request(self, request): req0, req1, policies=[header_policy, RequestPolicy()], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", # Fix it so test are deterministic - headers={'Accept': 'application/json'} + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", # Fix it so test are deterministic + headers={"Accept": "application/json"}, ) async with AsyncPipeline(transport) as pipeline: await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'Accept: application/json\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'Accept: application/json\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"Accept: application/json\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"Accept: application/json\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -174,48 +191,39 @@ async def on_request(self, request): @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_multipart_send_with_one_changeset(http_request): transport = MockAsyncHttpTransport() - requests = [ - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") - ] + requests = [http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1")] changeset = http_request("", "") - changeset.set_multipart_mixed( - *requests, - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" - ) + changeset.set_multipart_mixed(*requests, boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525") request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed( - changeset, - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" - ) + request.set_multipart_mixed(changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525") async with AsyncPipeline(transport) as pipeline: await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -227,13 +235,13 @@ async def test_multipart_send_with_multiple_changesets(http_request): changeset1.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1"), - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) changeset2 = http_request("", "") changeset2.set_multipart_mixed( http_request("DELETE", "/container2/blob2"), http_request("DELETE", "/container3/blob3"), - boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314" + boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") @@ -247,49 +255,49 @@ async def test_multipart_send_with_multiple_changesets(http_request): await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 3\r\n' - b'\r\n' - b'DELETE /container3/blob3 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 3\r\n" + b"\r\n" + b"DELETE /container3/blob3 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -302,49 +310,47 @@ async def test_multipart_send_with_combination_changeset_first(http_request): changeset.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1"), - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - changeset, - http_request("DELETE", "/container2/blob2"), - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" + changeset, http_request("DELETE", "/container2/blob2"), boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) async with AsyncPipeline(transport) as pipeline: await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -356,49 +362,47 @@ async def test_multipart_send_with_combination_changeset_last(http_request): changeset.set_multipart_mixed( http_request("DELETE", "/container1/blob1"), http_request("DELETE", "/container2/blob2"), - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - changeset, - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" + http_request("DELETE", "/container0/blob0"), changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) async with AsyncPipeline(transport) as pipeline: await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -408,77 +412,71 @@ async def test_multipart_send_with_combination_changeset_middle(http_request): transport = MockAsyncHttpTransport() changeset = http_request("", "") changeset.set_multipart_mixed( - http_request("DELETE", "/container1/blob1"), - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), changeset, http_request("DELETE", "/container2/blob2"), - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) async with AsyncPipeline(transport) as pipeline: await pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @pytest.mark.asyncio @pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) async def test_multipart_receive(http_request, mock_response): - class ResponsePolicy(object): def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None - response.http_response.headers['x-ms-fun'] = 'true' + response.http_response.headers["x-ms-fun"] = "true" class AsyncResponsePolicy(object): async def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None - response.http_response.headers['x-ms-async-fun'] = 'true' + response.http_response.headers["x-ms-async-fun"] = "true" req0 = http_request("DELETE", "/container0/blob0") req1 = http_request("DELETE", "/container1/blob1") request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed( - req0, - req1, - policies=[ResponsePolicy(), AsyncResponsePolicy()] - ) + request.set_multipart_mixed(req0, req1, policies=[ResponsePolicy(), AsyncResponsePolicy()]) body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" @@ -509,8 +507,8 @@ async def on_response(self, request, response): response = mock_response( request, - body_as_str.encode('ascii'), - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + body_as_str.encode("ascii"), + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed", ) parts = [] @@ -521,13 +519,13 @@ async def on_response(self, request, response): res0 = parts[0] assert res0.status_code == 202 - assert res0.headers['x-ms-fun'] == 'true' - assert res0.headers['x-ms-async-fun'] == 'true' + assert res0.headers["x-ms-fun"] == "true" + assert res0.headers["x-ms-async-fun"] == "true" res1 = parts[1] assert res1.status_code == 404 - assert res1.headers['x-ms-fun'] == 'true' - assert res1.headers['x-ms-async-fun'] == 'true' + assert res1.headers["x-ms-fun"] == "true" + assert res1.headers["x-ms-async-fun"] == "true" @pytest.mark.asyncio @@ -535,45 +533,42 @@ async def on_response(self, request, response): async def test_multipart_receive_with_one_changeset(http_request, mock_response): changeset = http_request("", "") changeset.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 202 Accepted\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 202 Accepted\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 202 Accepted\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 202 Accepted\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -591,75 +586,71 @@ async def test_multipart_receive_with_multiple_changesets(http_request, mock_res changeset1 = http_request("", "") changeset1.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1") ) changeset2 = http_request("", "") changeset2.set_multipart_mixed( - http_request("DELETE", "/container2/blob2"), - http_request("DELETE", "/container3/blob3") + http_request("DELETE", "/container2/blob2"), http_request("DELETE", "/container3/blob3") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset1, changeset2) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314"\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 3\r\n' - b'\r\n' - b'HTTP/1.1 409\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 3\r\n" + b"\r\n" + b"HTTP/1.1 409\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -678,55 +669,52 @@ async def test_multipart_receive_with_combination_changeset_first(http_request, changeset = http_request("", "") changeset.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset, http_request("DELETE", "/container2/blob2")) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -737,6 +725,7 @@ async def test_multipart_receive_with_combination_changeset_first(http_request, assert parts[1].status_code == 202 assert parts[2].status_code == 404 + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_raise_for_status_bad_response(mock_response): response = mock_response(request=None, body=None, content_type=None) @@ -744,6 +733,7 @@ def test_raise_for_status_bad_response(mock_response): with pytest.raises(HttpResponseError): response.raise_for_status() + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_raise_for_status_good_response(mock_response): response = mock_response(request=None, body=None, content_type=None) @@ -760,53 +750,49 @@ async def test_multipart_receive_with_combination_changeset_middle(http_request, request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - changeset, - http_request("DELETE", "/container2/blob2") + http_request("DELETE", "/container0/blob0"), changeset, http_request("DELETE", "/container2/blob2") ) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -824,56 +810,53 @@ async def test_multipart_receive_with_combination_changeset_last(http_request, m changeset = http_request("", "") changeset.set_multipart_mixed( - http_request("DELETE", "/container1/blob1"), - http_request("DELETE", "/container2/blob2") + http_request("DELETE", "/container1/blob1"), http_request("DELETE", "/container2/blob2") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(http_request("DELETE", "/container0/blob0"), changeset) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -898,21 +881,19 @@ async def test_multipart_receive_with_bom(http_request, mock_response): b"Content-Type: application/http\n" b"Content-Transfer-Encoding: binary\n" b"Content-ID: 0\n" - b'\r\n' - b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' - b'Content-Length: 220\r\n' - b'Content-Type: application/xml\r\n' - b'Server: Windows-Azure-Blob/1.0\r\n' - b'\r\n' + b"\r\n" + b"HTTP/1.1 400 One of the request inputs is not valid.\r\n" + b"Content-Length: 220\r\n" + b"Content-Type: application/xml\r\n" + b"Server: Windows-Azure-Blob/1.0\r\n" + b"\r\n" b'\xef\xbb\xbf\nInvalidInputOne' - b'of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z\n' + b"of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z\n" b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -922,7 +903,7 @@ async def test_multipart_receive_with_bom(http_request, mock_response): res0 = parts[0] assert res0.status_code == 400 - assert res0.body().startswith(b'\xef\xbb\xbf') + assert res0.body().startswith(b"\xef\xbb\xbf") @pytest.mark.asyncio @@ -960,8 +941,8 @@ async def test_recursive_multipart_receive(http_request, mock_response): response = mock_response( request, - body_as_str.encode('ascii'), - "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" + body_as_str.encode("ascii"), + "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6", ) parts = [] @@ -986,6 +967,7 @@ async def test_recursive_multipart_receive(http_request, mock_response): def test_aiohttp_loop(): import asyncio from azure.core.pipeline.transport import AioHttpTransport + loop = asyncio.get_event_loop() with pytest.raises(ValueError): transport = AioHttpTransport(loop=loop) diff --git a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py index e8e5b1b80236..0f28c70c2c4b 100644 --- a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py @@ -9,27 +9,26 @@ import pytest import sys from unittest.mock import Mock -from azure.core.pipeline import ( - PipelineResponse, - PipelineRequest, - PipelineContext -) +from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) from utils import HTTP_RESPONSES, request_and_responses_product, create_http_response + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -38,7 +37,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -49,16 +48,16 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'No body was attached to the request' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "No body was attached to the request" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() @@ -72,7 +71,7 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 4 messages_request1 = mock_handler.messages[0].message.split("\n") messages_response1 = mock_handler.messages[1].message.split("\n") @@ -80,22 +79,22 @@ def emit(self, record): messages_response2 = mock_handler.messages[3].message.split("\n") assert messages_request1[0] == "Request URL: 'http://localhost/'" assert messages_request1[1] == "Request method: 'GET'" - assert messages_request1[2] == 'Request headers:' - assert messages_request1[3] == 'No body was attached to the request' - assert messages_response1[0] == 'Response status: 202' - assert messages_response1[1] == 'Response headers:' + assert messages_request1[2] == "Request headers:" + assert messages_request1[3] == "No body was attached to the request" + assert messages_response1[0] == "Response status: 202" + assert messages_response1[1] == "Response headers:" assert messages_request2[0] == "Request URL: 'http://localhost/'" assert messages_request2[1] == "Request method: 'GET'" - assert messages_request2[2] == 'Request headers:' - assert messages_request2[3] == 'No body was attached to the request' - assert messages_response2[0] == 'Response status: 202' - assert messages_response2[1] == 'Response headers:' + assert messages_request2[2] == "Request headers:" + assert messages_request2[3] == "No body was attached to the request" + assert messages_response2[0] == "Response status: 202" + assert messages_response2[1] == "Response headers:" mock_handler.reset() # Headers and query parameters - policy.allowed_query_params = ['country'] + policy.allowed_query_params = ["country"] universal_request.headers = { "Accept": "Caramel", @@ -111,7 +110,7 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") @@ -119,39 +118,31 @@ def emit(self, record): assert messages_request[1] == "Request method: 'GET'" assert messages_request[2] == "Request headers:" # Dict not ordered in Python, exact logging order doesn't matter - assert set([ - messages_request[3], - messages_request[4] - ]) == set([ - " 'Accept': 'Caramel'", - " 'Hate': 'REDACTED'" - ]) - assert messages_request[5] == 'No body was attached to the request' + assert set([messages_request[3], messages_request[4]]) == set([" 'Accept': 'Caramel'", " 'Hate': 'REDACTED'"]) + assert messages_request[5] == "No body was attached to the request" assert messages_response[0] == "Response status: 202" assert messages_response[1] == "Response headers:" # Dict not ordered in Python, exact logging order doesn't matter - assert set([ - messages_response[2], - messages_response[3] - ]) == set([ - " 'Content-Type': 'Caramel'", - " 'HateToo': 'REDACTED'" - ]) + assert set([messages_response[2], messages_response[3]]) == set( + [" 'Content-Type': 'Caramel'", " 'HateToo': 'REDACTED'"] + ) mock_handler.reset() @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger_operation_level(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -159,9 +150,9 @@ def emit(self, record): logger.setLevel(logging.DEBUG) policy = HttpLoggingPolicy() - kwargs={'logger': logger} + kwargs = {"logger": logger} - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -172,16 +163,16 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'No body was attached to the request' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "No body was attached to the request" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() @@ -197,7 +188,7 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 4 messages_request1 = mock_handler.messages[0].message.split("\n") messages_response1 = mock_handler.messages[1].message.split("\n") @@ -205,30 +196,33 @@ def emit(self, record): messages_response2 = mock_handler.messages[3].message.split("\n") assert messages_request1[0] == "Request URL: 'http://localhost/'" assert messages_request1[1] == "Request method: 'GET'" - assert messages_request1[2] == 'Request headers:' - assert messages_request1[3] == 'No body was attached to the request' - assert messages_response1[0] == 'Response status: 202' - assert messages_response1[1] == 'Response headers:' + assert messages_request1[2] == "Request headers:" + assert messages_request1[3] == "No body was attached to the request" + assert messages_response1[0] == "Response status: 202" + assert messages_response1[1] == "Response headers:" assert messages_request2[0] == "Request URL: 'http://localhost/'" assert messages_request2[1] == "Request method: 'GET'" - assert messages_request2[2] == 'Request headers:' - assert messages_request2[3] == 'No body was attached to the request' - assert messages_response2[0] == 'Response status: 202' - assert messages_response2[1] == 'Response headers:' + assert messages_request2[2] == "Request headers:" + assert messages_request2[3] == "No body was attached to the request" + assert messages_response2[0] == "Response status: 202" + assert messages_response2[1] == "Response headers:" mock_handler.reset() + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger_with_body(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -237,7 +231,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") universal_request.body = "testbody" http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 @@ -247,16 +241,16 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'A body is sent with the request' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "A body is sent with the request" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() @@ -264,15 +258,17 @@ def emit(self, record): @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) @pytest.mark.skipif(sys.version_info < (3, 6), reason="types.AsyncGeneratorType does not exist in 3.5") def test_http_logger_with_generator_body(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -281,7 +277,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") mock = Mock() mock.__class__ = types.AsyncGeneratorType universal_request.body = mock @@ -293,15 +289,15 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'File upload' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "File upload" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() diff --git a/sdk/core/azure-core/tests/async_tests/test_paging_async.py b/sdk/core/azure-core/tests/async_tests/test_paging_async.py index 0cb60339aa8b..3f372ddba8fe 100644 --- a/sdk/core/azure-core/tests/async_tests/test_paging_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_paging_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from typing import AsyncIterator, TypeVar, List @@ -48,96 +48,72 @@ async def _as_list(async_iter: AsyncIterator[T]) -> List[T]: class TestPaging: - @pytest.mark.asyncio async def test_basic_paging(self): - async def get_next(continuation_token=None): - """Simplify my life and return JSON and not response, but should be response. - """ + """Simplify my life and return JSON and not response, but should be response.""" if not continuation_token: - return { - 'nextLink': 'page2', - 'value': ['value1.0', 'value1.1'] - } + return {"nextLink": "page2", "value": ["value1.0", "value1.1"]} else: - return { - 'nextLink': None, - 'value': ['value2.0', 'value2.1'] - } + return {"nextLink": None, "value": ["value2.0", "value2.1"]} async def extract_data(response): - return response['nextLink'], AsyncList(response['value']) + return response["nextLink"], AsyncList(response["value"]) pager = AsyncItemPaged(get_next, extract_data) result_iterated = await _as_list(pager) - assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated + assert ["value1.0", "value1.1", "value2.0", "value2.1"] == result_iterated @pytest.mark.asyncio async def test_advance_paging(self): - async def get_next(continuation_token=None): - """Simplify my life and return JSON and not response, but should be response. - """ + """Simplify my life and return JSON and not response, but should be response.""" if not continuation_token: - return { - 'nextLink': 'page2', - 'value': ['value1.0', 'value1.1'] - } + return {"nextLink": "page2", "value": ["value1.0", "value1.1"]} else: - return { - 'nextLink': None, - 'value': ['value2.0', 'value2.1'] - } + return {"nextLink": None, "value": ["value2.0", "value2.1"]} async def extract_data(response): - return response['nextLink'], AsyncList(response['value']) + return response["nextLink"], AsyncList(response["value"]) pager = AsyncItemPaged(get_next, extract_data).by_page() page1 = await pager.__anext__() - assert ['value1.0', 'value1.1'] == await _as_list(page1) + assert ["value1.0", "value1.1"] == await _as_list(page1) page2 = await pager.__anext__() - assert ['value2.0', 'value2.1'] == await _as_list(page2) + assert ["value2.0", "value2.1"] == await _as_list(page2) with pytest.raises(StopAsyncIteration): await pager.__anext__() - @pytest.mark.asyncio async def test_none_value(self): async def get_next(continuation_token=None): - return { - 'nextLink': None, - 'value': None - } + return {"nextLink": None, "value": None} async def extract_data(response): - return response['nextLink'], AsyncList(response['value'] or []) + return response["nextLink"], AsyncList(response["value"] or []) pager = AsyncItemPaged(get_next, extract_data) result_iterated = await _as_list(pager) assert len(result_iterated) == 0 - @pytest.mark.asyncio async def test_paging_continue_on_error(self): async def get_next(continuation_token=None): if not continuation_token: - return { - 'nextLink': 'foo', - 'value': ['bar'] - } + return {"nextLink": "foo", "value": ["bar"]} else: raise HttpResponseError() + async def extract_data(response): - return response['nextLink'], iter(response['value'] or []) - + return response["nextLink"], iter(response["value"] or []) + pager = AsyncItemPaged(get_next, extract_data) - assert await pager.__anext__() == 'bar' + assert await pager.__anext__() == "bar" with pytest.raises(HttpResponseError) as err: await pager.__anext__() - assert err.value.continuation_token == 'foo' + assert err.value.continuation_token == "foo" diff --git a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py index d72f716052da..0dfbbfdf01e1 100644 --- a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import sys from azure.core.pipeline import AsyncPipeline @@ -40,7 +40,7 @@ AsyncHttpTransport, AsyncioRequestsTransport, TrioRequestsTransport, - AioHttpTransport + AioHttpTransport, ) from azure.core.polling.async_base_polling import AsyncLROBasePolling @@ -76,7 +76,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): pipeline = AsyncPipeline(BrokenSender(), [SansIOHTTPPolicy()]) - req = http_request('GET', '/') + req = http_request("GET", "/") with pytest.raises(ValueError): await pipeline.run(req) @@ -89,15 +89,13 @@ def on_exception(self, requests, **kwargs): with pytest.raises(NotImplementedError): await pipeline.run(req) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_basic_aiohttp(port, http_request): request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()] async with AsyncPipeline(AioHttpTransport(), policies=policies) as pipeline: response = await pipeline.run(request) @@ -105,16 +103,14 @@ async def test_basic_aiohttp(port, http_request): # all we need to check is if we are able to make the call assert isinstance(response.http_response.status_code, int) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_basic_aiohttp_separate_session(port, http_request): session = aiohttp.ClientSession() request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()] transport = AioHttpTransport(session=session, session_owner=False) async with AsyncPipeline(transport, policies=policies) as pipeline: response = await pipeline.run(request) @@ -125,20 +121,19 @@ async def test_basic_aiohttp_separate_session(port, http_request): assert transport.session await transport.session.close() + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_basic_async_requests(port, http_request): request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()] async with AsyncPipeline(AsyncioRequestsTransport(), policies=policies) as pipeline: response = await pipeline.run(request) assert isinstance(response.http_response.status_code, int) + @pytest.mark.asyncio async def test_async_transport_sleep(): @@ -148,26 +143,27 @@ async def test_async_transport_sleep(): async with AioHttpTransport() as transport: await transport.sleep(1) + def test_polling_with_path_format_arguments(): - method = AsyncLROBasePolling( - timeout=0, - path_format_arguments={"host": "host:3000", "accountName": "local"} - ) + method = AsyncLROBasePolling(timeout=0, path_format_arguments={"host": "host:3000", "accountName": "local"}) client = AsyncPipelineClient(base_url="http://{accountName}{host}") method._operation = LocationPolling() method._operation._location_url = "/results/1" method._client = client - assert "http://localhost:3000/results/1" == method._client.format_url(method._operation.get_polling_url(), **method._path_format_arguments) + assert "http://localhost:3000/results/1" == method._client.format_url( + method._operation.get_polling_url(), **method._path_format_arguments + ) -def test_async_trio_transport_sleep(): +def test_async_trio_transport_sleep(): async def do(): async with TrioRequestsTransport() as transport: await transport.sleep(1) response = trio.run(do) + def test_default_http_logging_policy(): config = Configuration() pipeline_client = AsyncPipelineClient(base_url="test") @@ -176,70 +172,70 @@ def test_default_http_logging_policy(): assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST + def test_pass_in_http_logging_policy(): config = Configuration() http_logging_policy = HttpLoggingPolicy() - http_logging_policy.allowed_header_names.update( - {"x-ms-added-header"} - ) + http_logging_policy.allowed_header_names.update({"x-ms-added-header"}) config.http_logging_policy = http_logging_policy pipeline_client = AsyncPipelineClient(base_url="test") pipeline = pipeline_client._build_pipeline(config) http_logging_policy = pipeline._impl_policies[-1]._policy - assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) - assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST.union({"x-ms-added-header"}) + assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union( + {"x-ms-added-header"} + ) + assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST.union( + {"x-ms-added-header"} + ) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_conf_async_requests(port, http_request): request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()] async with AsyncPipeline(AsyncioRequestsTransport(), policies=policies) as pipeline: response = await pipeline.run(request) assert isinstance(response.http_response.status_code, int) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_conf_async_trio_requests(port, http_request): - async def do(): request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - AsyncRedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), AsyncRedirectPolicy()] async with AsyncPipeline(TrioRequestsTransport(), policies=policies) as pipeline: return await pipeline.run(request) response = trio.run(do) assert isinstance(response.http_response.status_code, int) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_retry_without_http_response(http_request): class NaughtyPolicy(AsyncHTTPPolicy): def send(*args): - raise AzureError('boo') + raise AzureError("boo") policies = [AsyncRetryPolicy(), NaughtyPolicy()] pipeline = AsyncPipeline(policies=policies, transport=None) with pytest.raises(AzureError): - await pipeline.run(http_request('GET', url='https://foo.bar')) + await pipeline.run(http_request("GET", url="https://foo.bar")) + @pytest.mark.asyncio async def test_add_custom_policy(): class BooPolicy(AsyncHTTPPolicy): def send(*args): - raise AzureError('boo') + raise AzureError("boo") class FooPolicy(AsyncHTTPPolicy): def send(*args): - raise AzureError('boo') + raise AzureError("boo") config = Configuration() retry_policy = AsyncRetryPolicy() @@ -274,8 +270,9 @@ def send(*args): pos_retry = policies.index(retry_policy) assert pos_boo > pos_retry - client = AsyncPipelineClient(base_url="test", config=config, per_call_policies=boo_policy, - per_retry_policies=foo_policy) + client = AsyncPipelineClient( + base_url="test", config=config, per_call_policies=boo_policy, per_retry_policies=foo_policy + ) policies = client._pipeline._impl_policies assert boo_policy in policies assert foo_policy in policies @@ -285,8 +282,9 @@ def send(*args): assert pos_boo < pos_retry assert pos_foo > pos_retry - client = AsyncPipelineClient(base_url="test", config=config, per_call_policies=[boo_policy], - per_retry_policies=[foo_policy]) + client = AsyncPipelineClient( + base_url="test", config=config, per_call_policies=[boo_policy], per_retry_policies=[foo_policy] + ) policies = client._pipeline._impl_policies assert boo_policy in policies assert foo_policy in policies @@ -296,9 +294,7 @@ def send(*args): assert pos_boo < pos_retry assert pos_foo > pos_retry - policies = [UserAgentPolicy(), - AsyncRetryPolicy(), - DistributedTracingPolicy()] + policies = [UserAgentPolicy(), AsyncRetryPolicy(), DistributedTracingPolicy()] client = AsyncPipelineClient(base_url="test", policies=policies, per_call_policies=boo_policy) actual_policies = client._pipeline._impl_policies assert boo_policy == actual_policies[0] @@ -313,19 +309,20 @@ def send(*args): actual_policies = client._pipeline._impl_policies assert foo_policy == actual_policies[2] - client = AsyncPipelineClient(base_url="test", policies=policies, per_call_policies=boo_policy, - per_retry_policies=[foo_policy]) + client = AsyncPipelineClient( + base_url="test", policies=policies, per_call_policies=boo_policy, per_retry_policies=[foo_policy] + ) actual_policies = client._pipeline._impl_policies assert boo_policy == actual_policies[0] assert foo_policy == actual_policies[3] - client = AsyncPipelineClient(base_url="test", policies=policies, per_call_policies=[boo_policy], - per_retry_policies=[foo_policy]) + client = AsyncPipelineClient( + base_url="test", policies=policies, per_call_policies=[boo_policy], per_retry_policies=[foo_policy] + ) actual_policies = client._pipeline._impl_policies assert boo_policy == actual_policies[0] assert foo_policy == actual_policies[3] - policies = [UserAgentPolicy(), - DistributedTracingPolicy()] + policies = [UserAgentPolicy(), DistributedTracingPolicy()] with pytest.raises(ValueError): client = AsyncPipelineClient(base_url="test", policies=policies, per_retry_policies=foo_policy) with pytest.raises(ValueError): diff --git a/sdk/core/azure-core/tests/async_tests/test_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_polling_async.py index 469f30ba42f3..45f1227fbb7c 100644 --- a/sdk/core/azure-core/tests/async_tests/test_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_polling_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,9 +22,10 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import asyncio import time + try: from unittest import mock except ImportError: @@ -35,6 +36,7 @@ from azure.core import AsyncPipelineClient from azure.core.polling import * from azure.core.exceptions import ServiceResponseError + # from msrest.serialization import Model @@ -49,34 +51,33 @@ async def test_no_polling(client): no_polling = AsyncNoPolling() initial_response = "initial response" + def deserialization_cb(response): assert response == initial_response - return "Treated: "+response + return "Treated: " + response no_polling.initialize(client, initial_response, deserialization_cb) - await no_polling.run() # Should no raise and do nothing + await no_polling.run() # Should no raise and do nothing assert no_polling.status() == "succeeded" assert no_polling.finished() - assert no_polling.resource() == "Treated: "+initial_response + assert no_polling.resource() == "Treated: " + initial_response continuation_token = no_polling.get_continuation_token() assert isinstance(continuation_token, str) no_polling_revived_args = NoPolling.from_continuation_token( - continuation_token, - deserialization_callback=deserialization_cb, - client=client + continuation_token, deserialization_callback=deserialization_cb, client=client ) no_polling_revived = NoPolling() no_polling_revived.initialize(*no_polling_revived_args) assert no_polling_revived.status() == "succeeded" assert no_polling_revived.finished() - assert no_polling_revived.resource() == "Treated: "+initial_response + assert no_polling_revived.resource() == "Treated: " + initial_response class PollingTwoSteps(AsyncPollingMethod): - """An empty poller that returns the deserialized initial response. - """ + """An empty poller that returns the deserialized initial response.""" + def __init__(self, sleep=0): self._initial_response = None self._deserialization_callback = None @@ -88,10 +89,9 @@ def initialize(self, _, initial_response, deserialization_callback): self._finished = False async def run(self): - """Empty run, no polling. - """ + """Empty run, no polling.""" self._finished = True - await asyncio.sleep(self._sleep) # Give me time to add callbacks! + await asyncio.sleep(self._sleep) # Give me time to add callbacks! def status(self): """Return the current status as a string. @@ -115,7 +115,7 @@ def get_continuation_token(self): def from_continuation_token(cls, continuation_token, **kwargs): # type(str, Any) -> Tuple initial_response = continuation_token - deserialization_callback = kwargs['deserialization_callback'] + deserialization_callback = kwargs["deserialization_callback"] return None, initial_response, deserialization_callback @@ -128,7 +128,7 @@ async def test_poller(client): # Same for deserialization_callback, just pass to the polling_method def deserialization_callback(response): assert response == initial_response - return "Treated: "+response + return "Treated: " + response method = AsyncNoPolling() @@ -140,7 +140,7 @@ def deserialization_callback(response): result = await poller assert poller.done() - assert result == "Treated: "+initial_response + assert result == "Treated: " + initial_response assert raw_poller.status() == "succeeded" assert raw_poller.polling_method() is method done_cb.assert_called_once_with(poller) @@ -160,7 +160,7 @@ def deserialization_callback(response): poller.remove_done_callback(done_cb2) result = await poller - assert result == "Treated: "+initial_response + assert result == "Treated: " + initial_response assert raw_poller.status() == "succeeded" done_cb.assert_called_once_with(poller) done_cb2.assert_not_called() @@ -174,23 +174,23 @@ def deserialization_callback(response): client=client, initial_response=initial_response, deserialization_callback=deserialization_callback, - polling_method=method + polling_method=method, ) result = await new_poller.result() - assert result == "Treated: "+initial_response + assert result == "Treated: " + initial_response assert new_poller.status() == "succeeded" @pytest.mark.asyncio async def test_broken_poller(client): - class NoPollingError(PollingTwoSteps): async def run(self): raise ValueError("Something bad happened") initial_response = "Initial response" + def deserialization_callback(response): - return "Treated: "+response + return "Treated: " + response method = NoPollingError() poller = AsyncLROPoller(client, initial_response, deserialization_callback, method) @@ -202,14 +202,14 @@ def deserialization_callback(response): @pytest.mark.asyncio async def test_async_poller_error_continuation(client): - class NoPollingError(PollingTwoSteps): async def run(self): raise ServiceResponseError("Something bad happened") initial_response = "Initial response" + def deserialization_callback(response): - return "Treated: "+response + return "Treated: " + response method = NoPollingError() poller = AsyncLROPoller(client, initial_response, deserialization_callback, method) diff --git a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py index 189a09eaccb4..753c0e146acb 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py @@ -28,18 +28,19 @@ async def __anext__(self): raise StopAsyncIteration async with AsyncioRequestsTransport() as transport: - req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) + req = http_request("GET", "http://localhost:{}/basic/anything".format(port), data=AsyncGen()) response = await transport.send(req) if is_rest(http_request): assert is_rest(response) - assert json.loads(response.text())['data'] == "azerty" + assert json.loads(response.text())["data"] == "azerty" + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_send_data(port, http_request): async with AsyncioRequestsTransport() as transport: - req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") + req = http_request("PUT", "http://localhost:{}/basic/anything".format(port), data=b"azerty") response = await transport.send(req) if is_rest(http_request): assert is_rest(response) - assert json.loads(response.text())['data'] == "azerty" + assert json.loads(response.text())["data"] == "azerty" diff --git a/sdk/core/azure-core/tests/async_tests/test_request_trio.py b/sdk/core/azure-core/tests/async_tests/test_request_trio.py index 092a550cb67b..668e07288824 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_trio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_trio.py @@ -29,18 +29,19 @@ async def __anext__(self): raise StopAsyncIteration async with TrioRequestsTransport() as transport: - req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) + req = http_request("GET", "http://localhost:{}/basic/anything".format(port), data=AsyncGen()) response = await transport.send(req) if is_rest(http_request): assert is_rest(response) - assert json.loads(response.text())['data'] == "azerty" + assert json.loads(response.text())["data"] == "azerty" + @pytest.mark.trio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_send_data(port, http_request): async with TrioRequestsTransport() as transport: - req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") + req = http_request("PUT", "http://localhost:{}/basic/anything".format(port), data=b"azerty") response = await transport.send(req) if is_rest(http_request): assert is_rest(response) - assert json.loads(response.text())['data'] == "azerty" \ No newline at end of file + assert json.loads(response.text())["data"] == "azerty" diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py b/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py index cfa7af589937..194efb4a8dde 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_asyncio_transport.py @@ -11,12 +11,14 @@ import pytest from utils import readonly_checks + @pytest.fixture async def client(port): async with AsyncioRequestsTransport() as transport: async with AsyncTestRestClient(port, transport=transport) as client: yield client + @pytest.mark.asyncio async def test_async_gen_data(port, client): class AsyncGen: @@ -32,16 +34,18 @@ async def __anext__(self): except StopIteration: raise StopAsyncIteration - request = HttpRequest('GET', 'http://localhost:{}/basic/anything'.format(port), content=AsyncGen()) + request = HttpRequest("GET", "http://localhost:{}/basic/anything".format(port), content=AsyncGen()) response = await client.send_request(request) - assert response.json()['data'] == "azerty" + assert response.json()["data"] == "azerty" + @pytest.mark.asyncio async def test_send_data(port, client): - request = HttpRequest('PUT', 'http://localhost:{}/basic/anything'.format(port), content=b"azerty") + request = HttpRequest("PUT", "http://localhost:{}/basic/anything".format(port), content=b"azerty") response = await client.send_request(request) - assert response.json()['data'] == "azerty" + assert response.json()["data"] == "azerty" + @pytest.mark.asyncio async def test_readonly(client): @@ -51,8 +55,10 @@ async def test_readonly(client): assert isinstance(response, RestAsyncioRequestsTransportResponse) from azure.core.pipeline.transport import AsyncioRequestsTransportResponse + readonly_checks(response, old_response_class=AsyncioRequestsTransportResponse) + @pytest.mark.asyncio async def test_decompress_compressed_header(client): # expect plain text @@ -63,6 +69,7 @@ async def test_decompress_compressed_header(client): assert response.content == content assert response.text() == "hello world" + @pytest.mark.asyncio async def test_deflate_decompress_compressed_header(client): # expect plain text @@ -73,6 +80,7 @@ async def test_deflate_decompress_compressed_header(client): assert response.content == content assert response.text() == "hi there" + @pytest.mark.asyncio async def test_decompress_compressed_header_stream(client): # expect plain text @@ -83,6 +91,7 @@ async def test_decompress_compressed_header_stream(client): assert response.content == content assert response.text() == "hello world" + @pytest.mark.asyncio async def test_decompress_compressed_header_stream_body_content(client): # expect plain text diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_context_manager_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_context_manager_async.py index 6866baa41f42..5aebfd5b9ed3 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_context_manager_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_context_manager_async.py @@ -9,12 +9,14 @@ from azure.core.rest import HttpRequest from rest_client_async import AsyncTestRestClient + @pytest.mark.asyncio async def test_normal_call(client): async def _raise_and_get_text(response): response.raise_for_status() assert response.text() == "Hello, world!" assert response.is_closed + request = HttpRequest("GET", url="/basic/string") response = await client.send_request(request) await _raise_and_get_text(response) @@ -27,6 +29,7 @@ async def _raise_and_get_text(response): async with response as response: await _raise_and_get_text(response) + @pytest.mark.asyncio async def test_stream_call(client): async def _raise_and_get_text(response): @@ -37,6 +40,7 @@ async def _raise_and_get_text(response): await response.read() assert response.text() == "Hello, world!" assert response.is_closed + request = HttpRequest("GET", url="/streams/basic") response = await client.send_request(request, stream=True) await _raise_and_get_text(response) @@ -50,6 +54,7 @@ async def _raise_and_get_text(response): async with response as response: await _raise_and_get_text(response) + # TODO: commenting until https://github.com/Azure/azure-sdk-for-python/issues/18086 is fixed # @pytest.mark.asyncio @@ -79,4 +84,4 @@ async def _raise_and_get_text(response): # assert error.error.code == "BadRequest" # assert error.error.message == "You made a bad request" # assert error.model.code == "BadRequest" -# assert error.error.message == "You made a bad request" \ No newline at end of file +# assert error.error.message == "You made a bad request" diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_headers_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_headers_async.py index 3b61c59b69d5..a0156b8ed738 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_headers_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_headers_async.py @@ -19,6 +19,7 @@ "Date", ] + @pytest.fixture def get_response_headers(client): async def _get_response_headers(request): @@ -27,8 +28,10 @@ async def _get_response_headers(request): for header in RESPONSE_HEADERS_TO_IGNORE: response.headers.pop(header, None) return response.headers + return _get_response_headers + @pytest.mark.asyncio async def test_headers_response(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) @@ -45,7 +48,7 @@ async def test_headers_response(get_response_headers): assert h.get("nope", default="default") is "default" assert h.get("nope", default=None) is None assert h.get("nope", default=[]) == [] - assert list(h) == ['a', 'b'] + assert list(h) == ["a", "b"] assert list(h.keys()) == ["a", "b"] assert list(h.values()) == ["123, 456", "789"] @@ -53,6 +56,7 @@ async def test_headers_response(get_response_headers): assert list(h) == ["a", "b"] assert dict(h) == {"a": "123, 456", "b": "789"} + @pytest.mark.asyncio async def test_headers_response_keys(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) @@ -66,14 +70,18 @@ async def test_headers_response_keys(get_response_headers): assert "B" in h.keys() assert set(h.keys()) == set(ref_dict.keys()) -@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="https://github.com/aio-libs/aiohttp/issues/5967") + +@pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="https://github.com/aio-libs/aiohttp/issues/5967" +) @pytest.mark.asyncio async def test_headers_response_keys_mutability(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # test mutability before_mutation_keys = h.keys() - h['c'] = '000' - assert 'c' in before_mutation_keys + h["c"] = "000" + assert "c" in before_mutation_keys + @pytest.mark.asyncio async def test_headers_response_values(get_response_headers): @@ -82,18 +90,22 @@ async def test_headers_response_values(get_response_headers): ref_dict = {"a": "123, 456", "b": "789"} assert list(h.values()) == list(ref_dict.values()) assert repr(h.values()) == repr(ref_dict.values()) - assert '123, 456' in h.values() - assert '789' in h.values() + assert "123, 456" in h.values() + assert "789" in h.values() assert set(h.values()) == set(ref_dict.values()) -@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="https://github.com/aio-libs/aiohttp/issues/5967") + +@pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="https://github.com/aio-libs/aiohttp/issues/5967" +) @pytest.mark.asyncio async def test_headers_response_values_mutability(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # test mutability before_mutation_values = h.values() - h['c'] = '000' - assert '000' in before_mutation_values + h["c"] = "000" + assert "000" in before_mutation_values + @pytest.mark.asyncio async def test_headers_response_items(get_response_headers): @@ -102,22 +114,26 @@ async def test_headers_response_items(get_response_headers): ref_dict = {"a": "123, 456", "b": "789"} assert list(h.items()) == list(ref_dict.items()) assert repr(h.items()) == repr(ref_dict.items()) - assert ("a", '123, 456') in h.items() - assert not ("a", '123, 456', '123, 456') in h.items() + assert ("a", "123, 456") in h.items() + assert not ("a", "123, 456", "123, 456") in h.items() assert not {"a": "blah", "123, 456": "blah"} in h.items() - assert ("A", '123, 456') in h.items() - assert ("b", '789') in h.items() - assert ("B", '789') in h.items() + assert ("A", "123, 456") in h.items() + assert ("b", "789") in h.items() + assert ("B", "789") in h.items() assert set(h.items()) == set(ref_dict.items()) -@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="https://github.com/aio-libs/aiohttp/issues/5967") + +@pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="https://github.com/aio-libs/aiohttp/issues/5967" +) @pytest.mark.asyncio async def test_headers_response_items_mutability(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # test mutability before_mutation_items = h.items() - h['c'] = '000' - assert ('c', '000') in before_mutation_items + h["c"] = "000" + assert ("c", "000") in before_mutation_items + @pytest.mark.asyncio async def test_header_mutations(get_response_headers): @@ -134,6 +150,7 @@ async def test_header_mutations(get_response_headers): del h["a"] assert dict(h) == {"b": "4"} + @pytest.mark.asyncio async def test_copy_headers_method(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/case-insensitive")) @@ -141,77 +158,80 @@ async def test_copy_headers_method(get_response_headers): assert h == headers_copy assert h is not headers_copy + @pytest.mark.asyncio async def test_headers_insert_retains_ordering(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/ordered")) h["b"] = "123" assert list(h.values()) == ["a", "123", "c"] + @pytest.mark.asyncio async def test_headers_insert_appends_if_new(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/case-insensitive")) h["d"] = "123" assert list(h.values()) == ["lowercase", "ALLCAPS", "camelCase", "123"] + @pytest.mark.asyncio async def test_headers_insert_removes_all_existing(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) h["a"] = "789" assert dict(h) == {"a": "789", "b": "789"} + @pytest.mark.asyncio async def test_headers_delete_removes_all_existing(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) del h["a"] assert dict(h) == {"b": "789"} + @pytest.mark.asyncio async def test_response_headers_case_insensitive(client): request = HttpRequest("GET", "/headers/case-insensitive") response = await client.send_request(request) response.raise_for_status() assert ( - response.headers["lowercase-header"] == - response.headers["LOWERCASE-HEADER"] == - response.headers["Lowercase-Header"] == - response.headers["lOwErCasE-HeADer"] == - "lowercase" + response.headers["lowercase-header"] + == response.headers["LOWERCASE-HEADER"] + == response.headers["Lowercase-Header"] + == response.headers["lOwErCasE-HeADer"] + == "lowercase" ) assert ( - response.headers["allcaps-header"] == - response.headers["ALLCAPS-HEADER"] == - response.headers["Allcaps-Header"] == - response.headers["AlLCapS-HeADer"] == - "ALLCAPS" + response.headers["allcaps-header"] + == response.headers["ALLCAPS-HEADER"] + == response.headers["Allcaps-Header"] + == response.headers["AlLCapS-HeADer"] + == "ALLCAPS" ) assert ( - response.headers["camelcase-header"] == - response.headers["CAMELCASE-HEADER"] == - response.headers["CamelCase-Header"] == - response.headers["cAMeLCaSE-hEadER"] == - "camelCase" + response.headers["camelcase-header"] + == response.headers["CAMELCASE-HEADER"] + == response.headers["CamelCase-Header"] + == response.headers["cAMeLCaSE-hEadER"] + == "camelCase" ) return response + @pytest.mark.asyncio async def test_multiple_headers_duplicate_case_insensitive(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/case-insensitive")) - assert ( - h["Duplicate-Header"] == - h['duplicate-header'] == - h['DupLicAte-HeaDER'] == - "one, two, three" - ) + assert h["Duplicate-Header"] == h["duplicate-header"] == h["DupLicAte-HeaDER"] == "one, two, three" + @pytest.mark.asyncio async def test_multiple_headers_commas(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/commas")) assert h["Set-Cookie"] == "a, b, c" + @pytest.mark.asyncio async def test_update(get_response_headers): h = await get_response_headers(HttpRequest("GET", "/headers/duplicate/commas")) assert h["Set-Cookie"] == "a, b, c" h.update({"Set-Cookie": "override", "new-key": "new-value"}) - assert h['Set-Cookie'] == 'override' + assert h["Set-Cookie"] == "override" assert h["new-key"] == "new-value" diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_http_request_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_http_request_async.py index 5ec0e71d0843..84f6926c7a32 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_http_request_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_http_request_async.py @@ -10,6 +10,7 @@ from azure.core.rest import HttpRequest import collections.abc + @pytest.fixture def assert_aiterator_body(): async def _comparer(request, final_value): @@ -18,8 +19,10 @@ async def _comparer(request, final_value): parts.append(part) content = b"".join(parts) assert content == final_value + return _comparer + def test_transfer_encoding_header(): async def streaming_body(data): yield data # pragma: nocover @@ -29,6 +32,7 @@ async def streaming_body(data): request = HttpRequest("POST", "http://example.org", data=data) assert "Content-Length" not in request.headers + def test_override_content_length_header(): async def streaming_body(data): yield data # pragma: nocover @@ -39,8 +43,9 @@ async def streaming_body(data): request = HttpRequest("POST", "http://example.org", data=data, headers=headers) assert request.headers["Content-Length"] == "0" + @pytest.mark.asyncio -async def test_aiterable_content(assert_aiterator_body): # cspell:disable-line +async def test_aiterable_content(assert_aiterator_body): # cspell:disable-line class Content: async def __aiter__(self): yield b"test 123" @@ -49,6 +54,7 @@ async def __aiter__(self): assert request.headers == {} await assert_aiterator_body(request, b"test 123") + @pytest.mark.asyncio async def test_aiterator_content(assert_aiterator_body): async def hello_world(): @@ -78,6 +84,7 @@ async def hello_world(): assert request.headers == {} await assert_aiterator_body(request, b"Hello, world!") + @pytest.mark.asyncio async def test_read_content(assert_aiterator_body): async def content(): diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py index aac14a6c27f6..5945aafb4abf 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_http_response_async.py @@ -14,6 +14,7 @@ from azure.core.exceptions import HttpResponseError from utils import readonly_checks + @pytest.fixture def send_request(client): async def _send_request(request): @@ -21,8 +22,10 @@ async def _send_request(request): response = await client.send_request(request, stream=False) response.raise_for_status() return response + return _send_request + @pytest.mark.asyncio async def test_response(send_request, port): response = await send_request( @@ -35,6 +38,7 @@ async def test_response(send_request, port): assert response.request.method == "GET" assert response.request.url == "http://localhost:{}/basic/string".format(port) + @pytest.mark.asyncio async def test_response_content(send_request): response = await send_request( @@ -46,6 +50,7 @@ async def test_response_content(send_request): assert content == b"Hello, world!" assert response.text() == "Hello, world!" + @pytest.mark.asyncio async def test_response_text(send_request): response = await send_request( @@ -56,8 +61,9 @@ async def test_response_text(send_request): content = await response.read() assert content == b"Hello, world!" assert response.text() == "Hello, world!" - assert response.headers["Content-Length"] == '13' - assert response.headers['Content-Type'] == "text/plain; charset=utf-8" + assert response.headers["Content-Length"] == "13" + assert response.headers["Content-Type"] == "text/plain; charset=utf-8" + @pytest.mark.asyncio async def test_response_html(send_request): @@ -70,6 +76,7 @@ async def test_response_html(send_request): assert content == b"Hello, world!" assert response.text() == "Hello, world!" + @pytest.mark.asyncio async def test_raise_for_status(client): # response = await client.send_request( @@ -92,23 +99,21 @@ async def test_raise_for_status(client): with pytest.raises(HttpResponseError): response.raise_for_status() + @pytest.mark.asyncio async def test_response_repr(send_request): - response = await send_request( - HttpRequest("GET", "/basic/string") - ) + response = await send_request(HttpRequest("GET", "/basic/string")) assert repr(response) == "" + @pytest.mark.asyncio async def test_response_content_type_encoding(send_request): """ Use the charset encoding in the Content-Type header if possible. """ - response = await send_request( - request=HttpRequest("GET", "/encoding/latin-1") - ) + response = await send_request(request=HttpRequest("GET", "/encoding/latin-1")) assert response.content_type == "text/plain; charset=latin-1" - assert response.content == b'Latin 1: \xff' + assert response.content == b"Latin 1: \xff" assert response.text() == "Latin 1: ÿ" assert response.encoding == "latin-1" @@ -118,10 +123,8 @@ async def test_response_autodetect_encoding(send_request): """ Autodetect encoding if there is no Content-Type header. """ - response = await send_request( - request=HttpRequest("GET", "/encoding/latin-1") - ) - assert response.text() == u'Latin 1: ÿ' + response = await send_request(request=HttpRequest("GET", "/encoding/latin-1")) + assert response.text() == "Latin 1: ÿ" assert response.encoding == "latin-1" @@ -130,9 +133,7 @@ async def test_response_fallback_to_autodetect(send_request): """ Fallback to autodetection if we get an invalid charset in the Content-Type header. """ - response = await send_request( - request=HttpRequest("GET", "/encoding/invalid-codec-name") - ) + response = await send_request(request=HttpRequest("GET", "/encoding/invalid-codec-name")) assert response.headers["Content-Type"] == "text/plain; charset=invalid-codec-name" assert response.text() == "おはようございます。" assert response.encoding is None @@ -168,6 +169,7 @@ async def test_response_no_charset_with_iso_8859_1_content(send_request): assert response.text() == "Accented: �sterreich" assert response.encoding is None + @pytest.mark.asyncio async def test_json(send_request): response = await send_request( @@ -176,6 +178,7 @@ async def test_json(send_request): assert response.json() == {"greeting": "hello", "recipient": "world"} assert response.encoding is None + @pytest.mark.asyncio async def test_json_with_specified_encoding(send_request): response = await send_request( @@ -184,6 +187,7 @@ async def test_json_with_specified_encoding(send_request): assert response.json() == {"greeting": "hello", "recipient": "world"} assert response.encoding == "utf-16" + @pytest.mark.asyncio async def test_emoji(send_request): response = await send_request( @@ -191,6 +195,7 @@ async def test_emoji(send_request): ) assert response.text() == "👩" + @pytest.mark.asyncio async def test_emoji_family_with_skin_tone_modifier(send_request): response = await send_request( @@ -198,6 +203,7 @@ async def test_emoji_family_with_skin_tone_modifier(send_request): ) assert response.text() == "👩🏻‍👩🏽‍👧🏾‍👦🏿 SSN: 859-98-0987" + @pytest.mark.asyncio async def test_korean_nfc(send_request): response = await send_request( @@ -205,16 +211,16 @@ async def test_korean_nfc(send_request): ) assert response.text() == "아가" + @pytest.mark.asyncio async def test_urlencoded_content(send_request): await send_request( request=HttpRequest( - "POST", - "/urlencoded/pet/add/1", - data={ "pet_type": "dog", "pet_food": "meat", "name": "Fido", "pet_age": 42 } + "POST", "/urlencoded/pet/add/1", data={"pet_type": "dog", "pet_food": "meat", "name": "Fido", "pet_age": 42} ), ) + # @pytest.mark.asyncio # async def test_multipart_files_content(send_request): # request = HttpRequest( @@ -224,6 +230,7 @@ async def test_urlencoded_content(send_request): # ) # await send_request(request) + @pytest.mark.asyncio async def test_send_request_return_pipeline_response(client): # we use return_pipeline_response for some cases in autorest @@ -235,22 +242,24 @@ async def test_send_request_return_pipeline_response(client): assert response.http_response.text() == "Hello, world!" assert hasattr(response.http_request, "content") + @pytest.mark.asyncio async def test_text_and_encoding(send_request): response = await send_request( request=HttpRequest("GET", "/encoding/emoji"), ) - assert response.content == u"👩".encode("utf-8") - assert response.text() == u"👩" + assert response.content == "👩".encode("utf-8") + assert response.text() == "👩" # try setting encoding as a property response.encoding = "utf-16" - assert response.text() == u"鿰ꦑ" == response.content.decode(response.encoding) + assert response.text() == "鿰ꦑ" == response.content.decode(response.encoding) # assert latin-1 changes text decoding without changing encoding property - assert response.text("latin-1") == 'ð\x9f\x91©' == response.content.decode("latin-1") + assert response.text("latin-1") == "ð\x9f\x91©" == response.content.decode("latin-1") assert response.encoding == "utf-16" + # @pytest.mark.asyncio # async def test_multipart_encode_non_seekable_filelike(send_request): # """ @@ -278,11 +287,13 @@ async def test_text_and_encoding(send_request): # ) # await send_request(request) + def test_initialize_response_abc(): with pytest.raises(TypeError) as ex: AsyncHttpResponse() assert "Can't instantiate abstract class" in str(ex) + @pytest.mark.asyncio async def test_readonly(send_request): """Make sure everything that is readonly is readonly""" @@ -290,4 +301,5 @@ async def test_readonly(send_request): assert isinstance(response, RestAioHttpTransportResponse) from azure.core.pipeline.transport import AioHttpTransportResponse + readonly_checks(response, old_response_class=AioHttpTransportResponse) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_polling_async.py index b80ac0596768..1f435cf7acaf 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_polling_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,39 +22,39 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import pytest from azure.core.exceptions import ServiceRequestError from azure.core.rest import HttpRequest from azure.core.polling import AsyncLROPoller from azure.core.polling.async_base_polling import AsyncLROBasePolling + @pytest.fixture def deserialization_callback(): def _callback(response): return response.http_response.json() + return _callback + @pytest.fixture def lro_poller(client, deserialization_callback): async def _callback(request, **kwargs): - initial_response = await client.send_request( - request=request, - _return_pipeline_response=True - ) + initial_response = await client.send_request(request=request, _return_pipeline_response=True) return AsyncLROPoller( - client._client, - initial_response, - deserialization_callback, - AsyncLROBasePolling(0, **kwargs) + client._client, initial_response, deserialization_callback, AsyncLROBasePolling(0, **kwargs) ) + return _callback + @pytest.mark.asyncio async def test_post_with_location_and_operation_location_headers(lro_poller): poller = await lro_poller(HttpRequest("POST", "/polling/post/location-and-operation-location")) result = await poller.result() - assert result == {'location_result': True} + assert result == {"location_result": True} + @pytest.mark.asyncio async def test_post_with_location_and_operation_location_headers_no_body(lro_poller): @@ -62,62 +62,74 @@ async def test_post_with_location_and_operation_location_headers_no_body(lro_pol result = await poller.result() assert result is None + @pytest.mark.asyncio async def test_post_resource_location(lro_poller): poller = await lro_poller(HttpRequest("POST", "/polling/post/resource-location")) result = await poller.result() - assert result == {'location_result': True} + assert result == {"location_result": True} + @pytest.mark.asyncio async def test_put_no_polling(lro_poller): result = await (await lro_poller(HttpRequest("PUT", "/polling/no-polling"))).result() - assert result['properties']['provisioningState'] == 'Succeeded' + assert result["properties"]["provisioningState"] == "Succeeded" + @pytest.mark.asyncio async def test_put_location(lro_poller): result = await (await lro_poller(HttpRequest("PUT", "/polling/location"))).result() - assert result['location_result'] + assert result["location_result"] + @pytest.mark.asyncio async def test_put_initial_response_body_invalid(lro_poller): # initial body is invalid result = await (await lro_poller(HttpRequest("PUT", "/polling/initial-body-invalid"))).result() - assert result['location_result'] + assert result["location_result"] + @pytest.mark.asyncio async def test_put_operation_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): await (await lro_poller(HttpRequest("PUT", "/polling/bad-operation-location"), retry_total=0)).result() + @pytest.mark.asyncio async def test_put_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): await (await lro_poller(HttpRequest("PUT", "/polling/bad-location"), retry_total=0)).result() + @pytest.mark.asyncio async def test_patch_location(lro_poller): result = await (await lro_poller(HttpRequest("PATCH", "/polling/location"))).result() - assert result['location_result'] + assert result["location_result"] + @pytest.mark.asyncio async def test_patch_operation_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): await (await lro_poller(HttpRequest("PUT", "/polling/bad-operation-location"), retry_total=0)).result() + @pytest.mark.asyncio async def test_patch_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): await (await lro_poller(HttpRequest("PUT", "/polling/bad-location"), retry_total=0)).result() + @pytest.mark.asyncio async def test_delete_operation_location(lro_poller): result = await (await lro_poller(HttpRequest("DELETE", "/polling/operation-location"))).result() - assert result['status'] == 'Succeeded' + assert result["status"] == "Succeeded" + @pytest.mark.asyncio async def test_request_id(lro_poller): result = await (await lro_poller(HttpRequest("POST", "/polling/request-id"), request_id="123456789")).result() - assert result['status'] == "Succeeded" + assert result["status"] == "Succeeded" + @pytest.mark.asyncio async def test_continuation_token(client, lro_poller, deserialization_callback): @@ -130,4 +142,4 @@ async def test_continuation_token(client, lro_poller, deserialization_callback): deserialization_callback=deserialization_callback, ) result = await new_poller.result() - assert result == {'location_result': True} + assert result == {"location_result": True} diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py index 9d45ea118e98..20248b9848f2 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_response_backcompat_async.py @@ -13,53 +13,64 @@ TRANSPORTS = [AioHttpTransport, AsyncioRequestsTransport] -@pytest.fixture +@pytest.fixture def old_request(port): return PipelineTransportHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + @pytest.fixture @pytest.mark.asyncio async def get_old_response(old_request): async def _callback(transport, **kwargs): async with transport() as sender: return await sender.send(old_request, **kwargs) + return _callback + @pytest.fixture @pytest.mark.trio async def get_old_response_trio(old_request): async def _callback(**kwargs): async with TrioRequestsTransport() as sender: return await sender.send(old_request, **kwargs) + return _callback + @pytest.fixture def new_request(port): return RestHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + @pytest.fixture @pytest.mark.asyncio async def get_new_response(new_request): async def _callback(transport, **kwargs): async with transport() as sender: return await sender.send(new_request, **kwargs) + return _callback + @pytest.fixture @pytest.mark.trio async def get_new_response_trio(new_request): async def _callback(**kwargs): async with TrioRequestsTransport() as sender: return await sender.send(new_request, **kwargs) + return _callback + def _test_response_attr_parity(old_response, new_response): for attr in dir(old_response): if not attr[0] == "_": # if not a private attr, we want parity assert hasattr(new_response, attr) + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_attr_parity(get_old_response, get_new_response, transport): @@ -67,12 +78,14 @@ async def test_response_attr_parity(get_old_response, get_new_response, transpor new_response = await get_new_response(transport) _test_response_attr_parity(old_response, new_response) + @pytest.mark.trio async def test_response_attr_parity_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_attr_parity(old_response, new_response) + def _test_response_set_attrs(old_response, new_response): for attr in dir(old_response): if attr[0] == "_": @@ -87,6 +100,7 @@ def _test_response_set_attrs(old_response, new_response): setattr(new_response, attr, "foo") assert getattr(old_response, attr) == getattr(new_response, attr) == "foo" + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_set_attrs(get_old_response, get_new_response, transport): @@ -94,18 +108,21 @@ async def test_response_set_attrs(get_old_response, get_new_response, transport) new_response = await get_new_response(transport) _test_response_set_attrs(old_response, new_response) + @pytest.mark.trio async def test_response_set_attrs_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_set_attrs(old_response, new_response) + def _test_response_block_size(old_response, new_response): assert old_response.block_size == new_response.block_size == 4096 old_response.block_size = 500 new_response.block_size = 500 assert old_response.block_size == new_response.block_size == 500 + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_block_size(get_old_response, get_new_response, transport): @@ -113,12 +130,14 @@ async def test_response_block_size(get_old_response, get_new_response, transport new_response = await get_new_response(transport) _test_response_block_size(old_response, new_response) + @pytest.mark.trio async def test_response_block_size_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_block_size(old_response, new_response) + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_body(get_old_response, get_new_response, transport): @@ -126,18 +145,25 @@ async def test_response_body(get_old_response, get_new_response, transport): new_response = await get_new_response(transport) assert old_response.body() == new_response.body() == b"Hello, world!" + @pytest.mark.trio async def test_response_body_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() assert old_response.body() == new_response.body() == b"Hello, world!" + def _test_response_internal_response(old_response, new_response, port): - assert str(old_response.internal_response.url) == str(new_response.internal_response.url) == "http://localhost:{}/streams/basic".format(port) + assert ( + str(old_response.internal_response.url) + == str(new_response.internal_response.url) + == "http://localhost:{}/streams/basic".format(port) + ) old_response.internal_response = "foo" new_response.internal_response = "foo" assert old_response.internal_response == new_response.internal_response == "foo" + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_internal_response(get_old_response, get_new_response, transport, port): @@ -145,12 +171,14 @@ async def test_response_internal_response(get_old_response, get_new_response, tr new_response = await get_new_response(transport) _test_response_internal_response(old_response, new_response, port) + @pytest.mark.trio async def test_response_internal_response_trio(get_old_response_trio, get_new_response_trio, port): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_internal_response(old_response, new_response, port) + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_stream_download(get_old_response, get_new_response, transport): @@ -165,6 +193,7 @@ async def test_response_stream_download(get_old_response, get_new_response, tran assert old_string in b"Hello, world!" assert new_string in b"Hello, world!" + @pytest.mark.trio async def test_response_stream_download_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio(stream=True) @@ -174,12 +203,14 @@ async def test_response_stream_download_trio(get_old_response_trio, get_new_resp new_string = b"".join([part async for part in new_response.stream_download(pipeline=pipeline)]) assert old_string == new_string == b"Hello, world!" + def _test_response_request(old_response, new_response, port): assert old_response.request.url == new_response.request.url == "http://localhost:{}/streams/basic".format(port) old_response.request = "foo" new_response.request = "foo" assert old_response.request == new_response.request == "foo" + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_request(get_old_response, get_new_response, port, transport): @@ -187,18 +218,21 @@ async def test_response_request(get_old_response, get_new_response, port, transp new_response = await get_new_response(transport) _test_response_request(old_response, new_response, port) + @pytest.mark.trio async def test_response_request_trio(get_old_response_trio, get_new_response_trio, port): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_request(old_response, new_response, port) + def _test_response_status_code(old_response, new_response): assert old_response.status_code == new_response.status_code == 200 old_response.status_code = 202 new_response.status_code = 202 assert old_response.status_code == new_response.status_code == 202 + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_status_code(get_old_response, get_new_response, transport): @@ -206,18 +240,25 @@ async def test_response_status_code(get_old_response, get_new_response, transpor new_response = await get_new_response(transport) _test_response_status_code(old_response, new_response) + @pytest.mark.trio async def test_response_status_code_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_status_code(old_response, new_response) + def _test_response_headers(old_response, new_response): - assert set(old_response.headers.keys()) == set(new_response.headers.keys()) == set(["Content-Type", "Connection", "Server", "Date"]) + assert ( + set(old_response.headers.keys()) + == set(new_response.headers.keys()) + == set(["Content-Type", "Connection", "Server", "Date"]) + ) old_response.headers = {"Hello": "world!"} new_response.headers = {"Hello": "world!"} assert old_response.headers == new_response.headers == {"Hello": "world!"} + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_headers(get_old_response, get_new_response, transport): @@ -225,18 +266,21 @@ async def test_response_headers(get_old_response, get_new_response, transport): new_response = await get_new_response(transport) _test_response_headers(old_response, new_response) + @pytest.mark.trio async def test_response_headers_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_headers(old_response, new_response) + def _test_response_reason(old_response, new_response): assert old_response.reason == new_response.reason == "OK" old_response.reason = "Not OK" new_response.reason = "Not OK" assert old_response.reason == new_response.reason == "Not OK" + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_reason(get_old_response, get_new_response, transport): @@ -244,18 +288,21 @@ async def test_response_reason(get_old_response, get_new_response, transport): new_response = await get_new_response(transport) _test_response_reason(old_response, new_response) + @pytest.mark.trio async def test_response_reason_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_reason(old_response, new_response) + def _test_response_content_type(old_response, new_response): assert old_response.content_type == new_response.content_type == "text/html; charset=utf-8" old_response.content_type = "application/json" new_response.content_type = "application/json" assert old_response.content_type == new_response.content_type == "application/json" + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) async def test_response_content_type(get_old_response, get_new_response, transport): @@ -263,26 +310,28 @@ async def test_response_content_type(get_old_response, get_new_response, transpo new_response = await get_new_response(transport) _test_response_content_type(old_response, new_response) + @pytest.mark.trio async def test_response_content_type_trio(get_old_response_trio, get_new_response_trio): old_response = await get_old_response_trio() new_response = await get_new_response_trio() _test_response_content_type(old_response, new_response) + def _create_multiapart_request(http_request_class): class ResponsePolicy(object): def on_request(self, *args): return def on_response(self, request, response): - response.http_response.headers['x-ms-fun'] = 'true' + response.http_response.headers["x-ms-fun"] = "true" class AsyncResponsePolicy(object): def on_request(self, *args): return async def on_response(self, request, response): - response.http_response.headers['x-ms-async-fun'] = 'true' + response.http_response.headers["x-ms-async-fun"] = "true" req0 = http_request_class("DELETE", "/container0/blob0") req1 = http_request_class("DELETE", "/container1/blob1") @@ -290,6 +339,7 @@ async def on_response(self, request, response): request.set_multipart_mixed(req0, req1, policies=[ResponsePolicy(), AsyncResponsePolicy()]) return request + async def _test_parts(response): # hack the content type parts = [p async for p in response.parts()] @@ -297,13 +347,14 @@ async def _test_parts(response): parts0 = parts[0] assert parts0.status_code == 202 - assert parts0.headers['x-ms-fun'] == 'true' - assert parts0.headers['x-ms-async-fun'] == 'true' + assert parts0.headers["x-ms-fun"] == "true" + assert parts0.headers["x-ms-async-fun"] == "true" parts1 = parts[1] assert parts1.status_code == 404 - assert parts1.headers['x-ms-fun'] == 'true' - assert parts1.headers['x-ms-async-fun'] == 'true' + assert parts1.headers["x-ms-fun"] == "true" + assert parts1.headers["x-ms-async-fun"] == "true" + @pytest.mark.asyncio @pytest.mark.parametrize("transport", TRANSPORTS) diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_stream_responses_async.py b/sdk/core/azure-core/tests/async_tests/test_rest_stream_responses_async.py index 575bee305794..740b5d80fa62 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_stream_responses_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_stream_responses_async.py @@ -8,6 +8,7 @@ from azure.core.rest import HttpRequest from azure.core.exceptions import StreamClosedError, StreamConsumedError, ResponseNotReadError + @pytest.mark.asyncio async def test_iter_raw(client): request = HttpRequest("GET", "/streams/basic") @@ -17,6 +18,7 @@ async def test_iter_raw(client): raw += part assert raw == b"Hello, world!" + @pytest.mark.asyncio async def test_iter_raw_on_iterable(client): request = HttpRequest("GET", "/streams/iterable") @@ -27,6 +29,7 @@ async def test_iter_raw_on_iterable(client): raw += part assert raw == b"Hello, world!" + @pytest.mark.asyncio async def test_iter_with_error(client): request = HttpRequest("GET", "/errors/403") @@ -52,6 +55,7 @@ async def test_iter_with_error(client): raise ValueError("Should error before entering") assert response.is_closed + @pytest.mark.asyncio async def test_iter_bytes(client): request = HttpRequest("GET", "/streams/basic") @@ -66,6 +70,7 @@ async def test_iter_bytes(client): assert response.is_closed assert raw == b"Hello, world!" + @pytest.mark.skip(reason="We've gotten rid of iter_text for now") @pytest.mark.asyncio async def test_iter_text(client): @@ -77,6 +82,7 @@ async def test_iter_text(client): content += part assert content == "Hello, world!" + @pytest.mark.skip(reason="We've gotten rid of iter_lines for now") @pytest.mark.asyncio async def test_iter_lines(client): @@ -102,6 +108,7 @@ async def test_streaming_response(client): assert response.content == b"Hello, world!" assert response.is_closed + @pytest.mark.asyncio async def test_cannot_read_after_stream_consumed(port, client): request = HttpRequest("GET", "/streams/basic") @@ -127,6 +134,7 @@ async def test_cannot_read_after_response_closed(port, client): assert "".format(port) in str(ex.value) assert "can no longer be read or streamed, since the response has already been closed" in str(ex.value) + @pytest.mark.asyncio async def test_decompress_plain_no_header(client): # thanks to Xiang Yan for this test! @@ -140,6 +148,7 @@ async def test_decompress_plain_no_header(client): await response.read() assert response.content == b"test" + @pytest.mark.asyncio async def test_compress_plain_no_header(client): # thanks to Xiang Yan for this test! @@ -154,6 +163,7 @@ async def test_compress_plain_no_header(client): data += d assert data == b"test" + @pytest.mark.asyncio async def test_iter_read_back_and_forth(client): # thanks to McCoy Patiño for this test! @@ -174,6 +184,7 @@ async def test_iter_read_back_and_forth(client): with pytest.raises(ResponseNotReadError): response.text() + @pytest.mark.asyncio async def test_stream_with_return_pipeline_response(client): request = HttpRequest("GET", "/basic/string") @@ -185,9 +196,10 @@ async def test_stream_with_return_pipeline_response(client): parts = [] async for part in pipeline_response.http_response.iter_bytes(): parts.append(part) - assert parts == [b'Hello, world!'] + assert parts == [b"Hello, world!"] await client.close() + @pytest.mark.asyncio async def test_error_reading(client): request = HttpRequest("GET", "/errors/403") @@ -203,6 +215,7 @@ async def test_error_reading(client): assert response.content == b"" await client.close() + @pytest.mark.asyncio async def test_pass_kwarg_to_iter_bytes(client): request = HttpRequest("GET", "/basic/string") @@ -210,6 +223,7 @@ async def test_pass_kwarg_to_iter_bytes(client): async for part in response.iter_bytes(chunk_size=5): assert part + @pytest.mark.asyncio async def test_pass_kwarg_to_iter_raw(client): request = HttpRequest("GET", "/basic/string") @@ -217,6 +231,7 @@ async def test_pass_kwarg_to_iter_raw(client): async for part in response.iter_raw(chunk_size=5): assert part + @pytest.mark.asyncio async def test_decompress_compressed_header(client): # expect plain text @@ -227,6 +242,7 @@ async def test_decompress_compressed_header(client): assert response.content == content assert response.text() == "hello world" + @pytest.mark.asyncio async def test_deflate_decompress_compressed_header(client): # expect plain text @@ -237,6 +253,7 @@ async def test_deflate_decompress_compressed_header(client): assert response.content == content assert response.text() == "hi there" + @pytest.mark.asyncio async def test_decompress_compressed_header_stream(client): # expect plain text @@ -247,6 +264,7 @@ async def test_decompress_compressed_header_stream(client): assert response.content == content assert response.text() == "hello world" + @pytest.mark.asyncio async def test_decompress_compressed_header_stream_body_content(client): # expect plain text diff --git a/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py b/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py index 813864544407..38075eef44ae 100644 --- a/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py +++ b/sdk/core/azure-core/tests/async_tests/test_rest_trio_transport.py @@ -10,12 +10,14 @@ from utils import readonly_checks import pytest + @pytest.fixture async def client(port): async with TrioRequestsTransport() as transport: async with AsyncTestRestClient(port, transport=transport) as client: yield client + @pytest.mark.trio async def test_async_gen_data(client, port): class AsyncGen: @@ -31,15 +33,17 @@ async def __anext__(self): except StopIteration: raise StopAsyncIteration - request = HttpRequest('GET', 'http://localhost:{}/basic/anything'.format(port), content=AsyncGen()) + request = HttpRequest("GET", "http://localhost:{}/basic/anything".format(port), content=AsyncGen()) response = await client.send_request(request) - assert response.json()['data'] == "azerty" + assert response.json()["data"] == "azerty" + @pytest.mark.trio async def test_send_data(port, client): - request = HttpRequest('PUT', 'http://localhost:{}/basic/anything'.format(port), content=b"azerty") + request = HttpRequest("PUT", "http://localhost:{}/basic/anything".format(port), content=b"azerty") response = await client.send_request(request) - assert response.json()['data'] == "azerty" + assert response.json()["data"] == "azerty" + @pytest.mark.trio async def test_readonly(client): @@ -49,8 +53,10 @@ async def test_readonly(client): assert isinstance(response, RestTrioRequestsTransportResponse) from azure.core.pipeline.transport import TrioRequestsTransportResponse + readonly_checks(response, old_response_class=TrioRequestsTransportResponse) + @pytest.mark.trio async def test_decompress_compressed_header(client): # expect plain text @@ -61,6 +67,7 @@ async def test_decompress_compressed_header(client): assert response.content == content assert response.text() == "hello world" + @pytest.mark.trio async def test_decompress_compressed_header_stream(client): # expect plain text @@ -71,6 +78,7 @@ async def test_decompress_compressed_header_stream(client): assert response.content == content assert response.text() == "hello world" + @pytest.mark.trio async def test_decompress_compressed_header_stream_body_content(client): # expect plain text diff --git a/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py index 270cc5c5684d..d86bdaad05f7 100644 --- a/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py @@ -16,7 +16,7 @@ ServiceRequestError, ServiceRequestTimeoutError, ServiceResponseError, - ServiceResponseTimeoutError + ServiceResponseTimeoutError, ) from azure.core.pipeline.policies import ( AsyncRetryPolicy, @@ -34,6 +34,7 @@ from itertools import product from utils import HTTP_REQUESTS + def test_retry_code_class_variables(): retry_policy = AsyncRetryPolicy() assert retry_policy._RETRY_CODES is not None @@ -41,13 +42,10 @@ def test_retry_code_class_variables(): assert 429 in retry_policy._RETRY_CODES assert 501 not in retry_policy._RETRY_CODES + def test_retry_types(): history = ["1", "2", "3"] - settings = { - 'history': history, - 'backoff': 1, - 'max_backoff': 10 - } + settings = {"history": history, "backoff": 1, "max_backoff": 10} retry_policy = AsyncRetryPolicy() backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 @@ -60,7 +58,8 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) + +@pytest.mark.parametrize("retry_after_input,http_request", product(["0", "800", "1000", "1200"], HTTP_REQUESTS)) def test_retry_after(retry_after_input, http_request): retry_policy = AsyncRetryPolicy() request = http_request("GET", "http://localhost") @@ -69,7 +68,7 @@ def test_retry_after(retry_after_input, http_request): pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) seconds = float(retry_after_input) - assert retry_after == seconds/1000.0 + assert retry_after == seconds / 1000.0 response.headers.pop("retry-after-ms") response.headers["Retry-After"] = retry_after_input retry_after = retry_policy.get_retry_after(pipeline_response) @@ -78,7 +77,8 @@ def test_retry_after(retry_after_input, http_request): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) + +@pytest.mark.parametrize("retry_after_input,http_request", product(["0", "800", "1000", "1200"], HTTP_REQUESTS)) def test_x_ms_retry_after(retry_after_input, http_request): retry_policy = AsyncRetryPolicy() request = http_request("GET", "http://localhost") @@ -87,7 +87,7 @@ def test_x_ms_retry_after(retry_after_input, http_request): pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) seconds = float(retry_after_input) - assert retry_after == seconds/1000.0 + assert retry_after == seconds / 1000.0 response.headers.pop("x-ms-retry-after-ms") response.headers["Retry-After"] = retry_after_input retry_after = retry_policy.get_retry_after(pipeline_response) @@ -96,16 +96,20 @@ def test_x_ms_retry_after(retry_after_input, http_request): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_retry_on_429(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 + async def __aexit__(self, exc_type, exc_val, exc_tb): pass + async def close(self): pass + async def open(self): pass @@ -115,23 +119,27 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe response.status_code = 429 return response - http_request = http_request('GET', 'http://localhost/') - http_retry = AsyncRetryPolicy(retry_total = 1) + http_request = http_request("GET", "http://localhost/") + http_retry = AsyncRetryPolicy(retry_total=1) transport = MockTransport() pipeline = AsyncPipeline(transport, [http_retry]) await pipeline.run(http_request) assert transport._count == 2 + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_no_retry_on_201(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 + async def __aexit__(self, exc_type, exc_val, exc_tb): pass + async def close(self): pass + async def open(self): pass @@ -143,31 +151,35 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe response.headers = headers return response - http_request = http_request('GET', 'http://localhost/') - http_retry = AsyncRetryPolicy(retry_total = 1) + http_request = http_request("GET", "http://localhost/") + http_retry = AsyncRetryPolicy(retry_total=1) transport = MockTransport() pipeline = AsyncPipeline(transport, [http_retry]) await pipeline.run(http_request) assert transport._count == 1 + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_retry_seekable_stream(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._first = True + async def __aexit__(self, exc_type, exc_val, exc_tb): pass + async def close(self): pass + async def open(self): pass async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse if self._first: self._first = False - request.body.seek(0,2) - raise AzureError('fail on first') + request.body.seek(0, 2) + raise AzureError("fail on first") position = request.body.tell() assert position == 0 response = HttpResponse(request, None) @@ -175,22 +187,26 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe return response data = BytesIO(b"Lots of dataaaa") - http_request = http_request('GET', 'http://localhost/') + http_request = http_request("GET", "http://localhost/") http_request.set_streamed_data_body(data) - http_retry = AsyncRetryPolicy(retry_total = 1) + http_retry = AsyncRetryPolicy(retry_total=1) pipeline = AsyncPipeline(MockTransport(), [http_retry]) await pipeline.run(http_request) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_retry_seekable_file(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._first = True + async def __aexit__(self, exc_type, exc_val, exc_tb): pass + async def close(self): pass + async def open(self): pass @@ -199,12 +215,12 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe self._first = False for value in request.files.values(): name, body = value[0], value[1] - if name and body and hasattr(body, 'read'): - body.seek(0,2) - raise AzureError('fail on first') + if name and body and hasattr(body, "read"): + body.seek(0, 2) + raise AzureError("fail on first") for value in request.files.values(): name, body = value[0], value[1] - if name and body and hasattr(body, 'read'): + if name and body and hasattr(body, "read"): position = body.tell() assert not position response = HttpResponse(request, None) @@ -212,15 +228,15 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe return response file = tempfile.NamedTemporaryFile(delete=False) - file.write(b'Lots of dataaaa') + file.write(b"Lots of dataaaa") file.close() - http_request = http_request('GET', 'http://localhost/') - headers = {'Content-Type': "multipart/form-data"} + http_request = http_request("GET", "http://localhost/") + headers = {"Content-Type": "multipart/form-data"} http_request.headers = headers - with open(file.name, 'rb') as f: + with open(file.name, "rb") as f: form_data_content = { - 'fileContent': f, - 'fileName': f.name, + "fileContent": f, + "fileName": f.name, } http_request.set_formdata_body(form_data_content) http_retry = AsyncRetryPolicy(retry_total=1) @@ -272,8 +288,10 @@ async def send(request, **kwargs): await pipeline.run(http_request("GET", "http://localhost/")) assert transport.send.call_count == 1, "policy should not retry: its first send succeeded" + combinations = [(ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)] + @pytest.mark.asyncio @pytest.mark.parametrize( "combinations,http_request", diff --git a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py index f9be72d1a9da..ad2e5d39cb02 100644 --- a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py @@ -14,6 +14,7 @@ import pytest from utils import request_and_responses_product, ASYNC_HTTP_RESPONSES, create_http_response + @pytest.mark.asyncio @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNC_HTTP_RESPONSES)) async def test_connection_error_response(http_request, http_response): @@ -32,18 +33,20 @@ def __init__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): pass + async def close(self): pass + async def open(self): pass async def send(self, request, **kwargs): - request = http_request('GET', 'http://localhost/') + request = http_request("GET", "http://localhost/") response = create_http_response(http_response, request, None) response.status_code = 200 return response - class MockContent(): + class MockContent: def __init__(self): self._first = True @@ -53,7 +56,7 @@ async def read(self, block_size): raise ConnectionError return None - class MockInternalResponse(): + class MockInternalResponse: def __init__(self): self.headers = {} self.content = MockContent() @@ -65,15 +68,16 @@ class AsyncMock(mock.MagicMock): async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) - http_request = http_request('GET', 'http://localhost/') + http_request = http_request("GET", "http://localhost/") pipeline = AsyncPipeline(MockTransport()) http_response = create_http_response(http_response, http_request, None) http_response.internal_response = MockInternalResponse() stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False) - with mock.patch('asyncio.sleep', new_callable=AsyncMock): + with mock.patch("asyncio.sleep", new_callable=AsyncMock): with pytest.raises(ConnectionError): await stream.__anext__() + @pytest.mark.asyncio @pytest.mark.parametrize("http_response", ASYNC_HTTP_RESPONSES) async def test_response_streaming_error_behavior(http_response): diff --git a/sdk/core/azure-core/tests/async_tests/test_streaming_async.py b/sdk/core/azure-core/tests/async_tests/test_streaming_async.py index 1385126b5ce1..b7b4f8cc17dc 100644 --- a/sdk/core/azure-core/tests/async_tests/test_streaming_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_streaming_async.py @@ -29,6 +29,7 @@ from azure.core.exceptions import DecodeError from utils import HTTP_REQUESTS + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_decompress_plain_no_header(http_request): @@ -45,9 +46,10 @@ async def test_decompress_plain_no_header(http_request): content = b"" async for d in data: content += d - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_compress_plain_no_header(http_request): @@ -64,9 +66,10 @@ async def test_compress_plain_no_header(http_request): content = b"" async for d in data: content += d - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_decompress_compressed_no_header(http_request): @@ -84,11 +87,12 @@ async def test_decompress_compressed_no_header(http_request): async for d in data: content += d try: - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert False except UnicodeDecodeError: pass + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_compress_compressed_no_header(http_request): @@ -106,7 +110,7 @@ async def test_compress_compressed_no_header(http_request): async for d in data: content += d try: - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert False except UnicodeDecodeError: pass @@ -118,6 +122,7 @@ async def test_compress_compressed_no_header(http_request): async def test_decompress_plain_header(http_request): # expect error import zlib + account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) @@ -135,6 +140,7 @@ async def test_decompress_plain_header(http_request): except (zlib.error, DecodeError): pass + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_compress_plain_header(http_request): @@ -151,9 +157,10 @@ async def test_compress_plain_header(http_request): content = b"" async for d in data: content += d - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_decompress_compressed_header(http_request): @@ -170,9 +177,10 @@ async def test_decompress_compressed_header(http_request): content = b"" async for d in data: content += d - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.live_test_only @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -191,7 +199,7 @@ async def test_compress_compressed_header(http_request): async for d in data: content += d try: - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert False except UnicodeDecodeError: pass diff --git a/sdk/core/azure-core/tests/async_tests/test_testserver_async.py b/sdk/core/azure-core/tests/async_tests/test_testserver_async.py index d6557b2b3e7c..bf294e54a105 100644 --- a/sdk/core/azure-core/tests/async_tests/test_testserver_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_testserver_async.py @@ -26,8 +26,10 @@ import pytest from azure.core.pipeline.transport import AioHttpTransport from utils import HTTP_REQUESTS + """This file does a simple call to the testserver to make sure we can use the testserver""" + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_smoke(port, http_request): diff --git a/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py b/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py index eefda90c3737..b668e49ce000 100644 --- a/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py @@ -56,7 +56,7 @@ async def make_request(self, numb_times, **kwargs): return None response = self.pipeline.run(self.request, **kwargs) await self.get_foo(merge_span=True) - kwargs['merge_span'] = True + kwargs["merge_span"] = True await self.make_request(numb_times - 1, **kwargs) return response @@ -77,7 +77,7 @@ async def get_foo(self): async def check_name_is_different(self): time.sleep(0.001) - @distributed_trace_async(tracing_attributes={'foo': 'bar'}) + @distributed_trace_async(tracing_attributes={"foo": "bar"}) async def tracing_attr(self): time.sleep(0.001) @@ -92,7 +92,6 @@ async def raising_exception(self): @pytest.mark.usefixtures("fake_span") class TestAsyncDecorator(object): - @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_decorator_tracing_attr(self, http_request): @@ -104,8 +103,7 @@ async def test_decorator_tracing_attr(self, http_request): assert parent.children[0].name == "MockClient.__init__" assert parent.children[1].name == "MockClient.tracing_attr" assert parent.children[1].kind == SpanKind.INTERNAL - assert parent.children[1].attributes == {'foo': 'bar'} - + assert parent.children[1].attributes == {"foo": "bar"} @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -146,7 +144,6 @@ async def test_used(self, http_request): assert parent.children[2].name == "MockClient.get_foo" assert not parent.children[2].children - @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_span_merge_span(self, http_request): @@ -163,7 +160,6 @@ async def test_span_merge_span(self, http_request): assert parent.children[2].name == "MockClient.no_merge_span_method" assert parent.children[2].children[0].name == "MockClient.get_foo" - @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_span_complicated(self, http_request): @@ -189,8 +185,7 @@ async def test_span_complicated(self, http_request): @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_span_with_exception(self, http_request): - """Assert that if an exception is raised, the next sibling method is actually a sibling span. - """ + """Assert that if an exception is raised, the next sibling method is actually a sibling span.""" with FakeSpan(name="parent") as parent: client = MockClient(http_request) try: @@ -203,5 +198,5 @@ async def test_span_with_exception(self, http_request): assert parent.children[0].name == "MockClient.__init__" assert parent.children[1].name == "MockClient.raising_exception" # Exception should propagate status for Opencensus - assert parent.children[1].status == 'Something went horribly wrong here' + assert parent.children[1].status == "Something went horribly wrong here" assert parent.children[2].name == "MockClient.get_foo" diff --git a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py index 04d11c2f9010..23e1e6e867bf 100644 --- a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,12 +22,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from azure.core.pipeline.transport import ( AioHttpTransport, AioHttpTransportResponse, AsyncioRequestsTransport, - TrioRequestsTransport) + TrioRequestsTransport, +) import aiohttp import trio @@ -36,6 +37,7 @@ from utils import HTTP_REQUESTS, AIOHTTP_TRANSPORT_RESPONSES, create_transport_response from azure.core.pipeline._tools import is_rest + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_basic_aiohttp(port, http_request): @@ -48,6 +50,7 @@ async def test_basic_aiohttp(port, http_request): assert sender.session is None assert isinstance(response.status_code, int) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_aiohttp_auto_headers(port, http_request): @@ -56,7 +59,8 @@ async def test_aiohttp_auto_headers(port, http_request): async with AioHttpTransport() as sender: response = await sender.send(request) auto_headers = response.internal_response.request_info.headers - assert 'Content-Type' not in auto_headers + assert "Content-Type" not in auto_headers + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -69,6 +73,7 @@ async def test_basic_async_requests(port, http_request): assert isinstance(response.status_code, int) + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_conf_async_requests(port, http_request): @@ -80,9 +85,9 @@ async def test_conf_async_requests(port, http_request): assert isinstance(response.status_code, int) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_conf_async_trio_requests(port, http_request): - async def do(): request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with TrioRequestsTransport() as sender: @@ -104,11 +109,7 @@ def __init__(self, body_bytes, headers=None): req_response = MockAiohttpClientResponse(body_bytes, headers) - response = create_transport_response( - http_response, - None, # Don't need a request here - req_response - ) + response = create_transport_response(http_response, None, req_response) # Don't need a request here response._content = body_bytes return response @@ -120,14 +121,11 @@ async def test_aiohttp_response_text(http_response): for encoding in ["utf-8", "utf-8-sig", None]: - res = _create_aiohttp_response( - http_response, - b'\xef\xbb\xbf56', - {'Content-Type': 'text/plain'} - ) + res = _create_aiohttp_response(http_response, b"\xef\xbb\xbf56", {"Content-Type": "text/plain"}) if is_rest(http_response): await res.read() - assert res.text(encoding) == '56', "Encoding {} didn't work".format(encoding) + assert res.text(encoding) == "56", "Encoding {} didn't work".format(encoding) + @pytest.mark.asyncio @pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) @@ -144,21 +142,25 @@ async def test_aiohttp_response_decompression(http_response): b"\xe4o\xc6T\xdeVw\x9dgL\x7f\xe0n\xc0\x91q\x02'w0b\x98JZe^\x89|\xce\x9b" b"\x0e\xcbW\x8a\x97\xf4X\x97\xc8\xbf\xfeYU\x1d\xc2\x85\xfc\xf4@\xb7\xbe" b"\xf7+&$\xf6\xa9\x8a\xcb\x96\xdc\xef\xff\xaa\xa1\x1c\xf9$\x01\x00\x00", - {'Content-Type': 'text/plain', 'Content-Encoding':"gzip"} + {"Content-Type": "text/plain", "Content-Encoding": "gzip"}, ) # cSpell:enable body = res.body() - expect = b'{"id":"e7877039-1376-4dcd-9b0a-192897cff780","createdDateTimeUtc":' \ - b'"2021-05-07T17:35:36.3121065Z","lastActionDateTimeUtc":' \ - b'"2021-05-07T17:35:36.3121069Z","status":"NotStarted",' \ - b'"summary":{"total":0,"failed":0,"success":0,"inProgress":0,' \ - b'"notYetStarted":0,"cancelled":0,"totalCharacterCharged":0}}' + expect = ( + b'{"id":"e7877039-1376-4dcd-9b0a-192897cff780","createdDateTimeUtc":' + b'"2021-05-07T17:35:36.3121065Z","lastActionDateTimeUtc":' + b'"2021-05-07T17:35:36.3121069Z","status":"NotStarted",' + b'"summary":{"total":0,"failed":0,"success":0,"inProgress":0,' + b'"notYetStarted":0,"cancelled":0,"totalCharacterCharged":0}}' + ) assert res.body() == expect, "Decompression didn't work" + @pytest.mark.asyncio @pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) async def test_aiohttp_response_decompression_negative(http_response): import zlib + # cSpell:disable res = _create_aiohttp_response( http_response, @@ -170,19 +172,16 @@ async def test_aiohttp_response_decompression_negative(http_response): b"\xe4o\xc6T\xdeVw\x9dgL\x7f\xe0n\xc0\x91q\x02'w0b\x98JZe^\x89|\xce\x9b" b"\x0e\xcbW\x8a\x97\xf4X\x97\xc8\xbf\xfeYU\x1d\xc2\x85\xfc\xf4@\xb7\xbe" b"\xf7+&$\xf6\xa9\x8a\xcb\x96\xdc\xef\xff\xaa\xa1\x1c\xf9$\x01\x00\x00", - {'Content-Type': 'text/plain', 'Content-Encoding':"gzip"} + {"Content-Type": "text/plain", "Content-Encoding": "gzip"}, ) # cSpell:enable with pytest.raises(zlib.error): body = res.body() + @pytest.mark.parametrize("http_response", AIOHTTP_TRANSPORT_RESPONSES) def test_repr(http_response): - res = _create_aiohttp_response( - http_response, - b'\xef\xbb\xbf56', - {} - ) + res = _create_aiohttp_response(http_response, b"\xef\xbb\xbf56", {}) res.content_type = "text/plain" class_name = "AsyncHttpResponse" if is_rest(http_response) else "AioHttpTransportResponse" diff --git a/sdk/core/azure-core/tests/conftest.py b/sdk/core/azure-core/tests/conftest.py index 0937d265b648..e0790cca3c9f 100644 --- a/sdk/core/azure-core/tests/conftest.py +++ b/sdk/core/azure-core/tests/conftest.py @@ -39,10 +39,12 @@ try: from azure.core.tracing.ext.opencensus_span import OpenCensusSpan from opencensus.trace.tracer import Tracer + Tracer() except ImportError: pass + def is_port_available(port_num): req = urllib.request.Request("http://localhost:{}/health".format(port_num)) try: @@ -50,6 +52,7 @@ def is_port_available(port_num): except Exception as e: return True + def get_port(): count = 3 for _ in range(count): @@ -58,24 +61,26 @@ def get_port(): return port_num raise TypeError("Tried {} times, can't find an open port".format(count)) + @pytest.fixture def port(): return os.environ["FLASK_PORT"] + def start_testserver(): port = get_port() os.environ["FLASK_APP"] = "coretestserver" os.environ["FLASK_PORT"] = str(port) - if platform.python_implementation() == 'PyPy': + if platform.python_implementation() == "PyPy": # pypy is now getting mad at us for some of our encoding / text, so need # to set these additional env vars for pypy os.environ["LC_ALL"] = "C.UTF-8" os.environ["LANG"] = "C.UTF-8" cmd = "flask run -p {}".format(port) - if os.name == 'nt': #On windows, subprocess creation works without being in the shell + if os.name == "nt": # On windows, subprocess creation works without being in the shell child_process = subprocess.Popen(cmd, env=dict(os.environ)) else: - #On linux, have to set shell=True + # On linux, have to set shell=True child_process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setsid, env=dict(os.environ)) count = 5 for _ in range(count): @@ -84,12 +89,14 @@ def start_testserver(): time.sleep(1) raise ValueError("Didn't start!") + def terminate_testserver(process): - if os.name == 'nt': + if os.name == "nt": process.kill() else: os.killpg(os.getpgid(process.pid), signal.SIGTERM) # Send the signal to all the process groups + @pytest.fixture(autouse=True, scope="package") def testserver(): """Start the Autorest testserver.""" @@ -97,6 +104,7 @@ def testserver(): yield terminate_testserver(server) + @pytest.fixture def client(port): return TestRestClient(port) diff --git a/sdk/core/azure-core/tests/rest_client.py b/sdk/core/azure-core/tests/rest_client.py index 0589ab1bf6bd..25765faeac7a 100644 --- a/sdk/core/azure-core/tests/rest_client.py +++ b/sdk/core/azure-core/tests/rest_client.py @@ -1,4 +1,3 @@ - # -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. @@ -31,18 +30,14 @@ class TestRestClientConfiguration(Configuration): - def __init__( - self, **kwargs - ): + def __init__(self, **kwargs): # type: (...) -> None super(TestRestClientConfiguration, self).__init__(**kwargs) kwargs.setdefault("sdk_moniker", "autorestswaggerbatfileservice/1.0.0b1") self._configure(**kwargs) - def _configure( - self, **kwargs - ): + def _configure(self, **kwargs): # type: (...) -> None self.user_agent_policy = kwargs.get("user_agent_policy") or policies.UserAgentPolicy(**kwargs) self.headers_policy = kwargs.get("headers_policy") or policies.HeadersPolicy(**kwargs) @@ -54,15 +49,11 @@ def _configure( self.redirect_policy = kwargs.get("redirect_policy") or policies.RedirectPolicy(**kwargs) self.authentication_policy = kwargs.get("authentication_policy") -class TestRestClient(object): +class TestRestClient(object): def __init__(self, port, **kwargs): self._config = TestRestClientConfiguration(**kwargs) - self._client = PipelineClient( - base_url="http://localhost:{}/".format(port), - config=self._config, - **kwargs - ) + self._client = PipelineClient(base_url="http://localhost:{}/".format(port), config=self._config, **kwargs) def send_request(self, request, **kwargs): """Runs the network request through the client's chained policies. diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index 1c8f0123322e..2b115046e32c 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -245,6 +245,7 @@ def raise_the_second_time(*args, **kwargs): raise_the_second_time.calls = 1 return Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'}) raise TestException() + raise_the_second_time.calls = 0 policy = TestPolicy(credential, "scope") @@ -270,7 +271,7 @@ def test_key_vault_regression(http_request): assert policy._credential is credential headers = {} - token = "alphanums" # cspell:disable-line + token = "alphanums" # cspell:disable-line policy._update_headers(headers, token) assert headers["Authorization"] == "Bearer " + token @@ -321,31 +322,51 @@ def test_azure_key_credential_updates(): credential.update(api_key) assert credential.key == api_key + combinations = [ ("sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), ("?sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), - ("sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"), - ("?sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"), + ( + "sig=test_signature", + "https://test_sas_credential?sig=test_signature", + "https://test_sas_credential?sig=test_signature", + ), + ( + "?sig=test_signature", + "https://test_sas_credential?sig=test_signature", + "https://test_sas_credential?sig=test_signature", + ), ("sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"), ("?sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"), - ("sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), - ("?sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), + ( + "sig=test_signature", + "https://test_sas_credential?foo=bar", + "https://test_sas_credential?foo=bar&sig=test_signature", + ), + ( + "?sig=test_signature", + "https://test_sas_credential?foo=bar", + "https://test_sas_credential?foo=bar&sig=test_signature", + ), ] + @pytest.mark.parametrize("combinations,http_request", product(combinations, HTTP_REQUESTS)) def test_azure_sas_credential_policy(combinations, http_request): """Tests to see if we can create an AzureSasCredentialPolicy""" sas, url, expected_url = combinations + def verify_authorization(request): assert request.url == expected_url - transport=Mock(send=verify_authorization) + transport = Mock(send=verify_authorization) credential = AzureSasCredential(sas) credential_policy = AzureSasCredentialPolicy(credential=credential) pipeline = Pipeline(transport=transport, policies=[credential_policy]) pipeline.run(http_request("GET", url)) + def test_azure_sas_credential_updates(): """Tests AzureSasCredential updates""" sas = "original" @@ -357,12 +378,14 @@ def test_azure_sas_credential_updates(): credential.update(sas) assert credential.signature == sas + def test_azure_sas_credential_policy_raises(): """Tests AzureSasCredential and AzureSasCredentialPolicy raises with non-string input parameters.""" sas = 1234 with pytest.raises(TypeError): credential = AzureSasCredential(sas) + def test_azure_named_key_credential(): cred = AzureNamedKeyCredential("sample_name", "samplekey") @@ -375,6 +398,7 @@ def test_azure_named_key_credential(): assert cred.named_key.key == "newkey" assert isinstance(cred.named_key, tuple) + def test_azure_named_key_credential_raises(): with pytest.raises(TypeError, match="Both name and key must be strings."): cred = AzureNamedKeyCredential("sample_name", 123345) diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py index c0a403de289a..b470f088ee75 100644 --- a/sdk/core/azure-core/tests/test_base_polling.py +++ b/sdk/core/azure-core/tests/test_base_polling.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import base64 import datetime import json @@ -52,6 +52,7 @@ from azure.core.pipeline._tools import is_rest from rest_client import TestRestClient + class SimpleResource: """An implementation of Python 3 SimpleNamespace. Used to deserialize resource objects from response bodies where @@ -69,25 +70,33 @@ def __repr__(self): def __eq__(self, other): return self.__dict__ == other.__dict__ + class BadEndpointError(Exception): pass -TEST_NAME = 'foo' -RESPONSE_BODY = {'properties':{'provisioningState': 'InProgress'}} -ASYNC_BODY = json.dumps({ 'status': 'Succeeded' }) -ASYNC_URL = 'http://dummyurlFromAzureAsyncOPHeader_Return200' -LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) -LOCATION_URL = 'http://dummyurlurlFromLocationHeader_Return200' -RESOURCE_BODY = json.dumps({ 'name': TEST_NAME }) -RESOURCE_URL = 'http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1' -ERROR = 'http://dummyurl_ReturnError' + +TEST_NAME = "foo" +RESPONSE_BODY = {"properties": {"provisioningState": "InProgress"}} +ASYNC_BODY = json.dumps({"status": "Succeeded"}) +ASYNC_URL = "http://dummyurlFromAzureAsyncOPHeader_Return200" +LOCATION_BODY = json.dumps({"name": TEST_NAME}) +LOCATION_URL = "http://dummyurlurlFromLocationHeader_Return200" +RESOURCE_BODY = json.dumps({"name": TEST_NAME}) +RESOURCE_URL = "http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1" +ERROR = "http://dummyurl_ReturnError" POLLING_STATUS = 200 CLIENT = PipelineClient("http://example.org") CLIENT.http_request_type = None CLIENT.http_response_type = None + + def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(client_self.http_request_type, client_self.http_response_type, request.url, request.headers) + return TestBasePolling.mock_update( + client_self.http_request_type, client_self.http_response_type, request.url, request.headers + ) + + CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -103,21 +112,23 @@ def pipeline_client_builder(): send will receive "request" and kwargs as any transport layer """ + def create_client(send_cb): class TestHttpTransport(HttpTransport): - def open(self): pass - def close(self): pass - def __exit__(self, *args, **kwargs): pass + def open(self): + pass + + def close(self): + pass + + def __exit__(self, *args, **kwargs): + pass def send(self, request, **kwargs): return send_cb(request, **kwargs) - return PipelineClient( - 'http://example.org/', - pipeline=Pipeline( - transport=TestHttpTransport() - ) - ) + return PipelineClient("http://example.org/", pipeline=Pipeline(transport=TestHttpTransport())) + return create_client @@ -125,6 +136,7 @@ def send(self, request, **kwargs): def deserialization_cb(): def cb(pipeline_response): return json.loads(pipeline_response.http_response.text()) + return cb @@ -142,15 +154,13 @@ def _callback(http_response, headers={}): None, response, ) - polling._pipeline_response = PipelineResponse( - None, - response, - PipelineContext(None) - ) + polling._pipeline_response = PipelineResponse(None, response, PipelineContext(None)) polling._initial_response = polling._pipeline_response return polling + return _callback + @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_base_polling_continuation_token(client, polling_response, http_response): polling = polling_response(http_response) @@ -166,166 +176,138 @@ def test_base_polling_continuation_token(client, polling_response, http_response new_polling = LROBasePolling() new_polling.initialize(*polling_args) + @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_delay_extraction_int(polling_response, http_response): polling = polling_response(http_response, {"Retry-After": "10"}) assert polling._extract_delay() == 10 -@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="https://stackoverflow.com/questions/11146725/isinstance-and-mocking") +@pytest.mark.skipif( + platform.python_implementation() == "PyPy", + reason="https://stackoverflow.com/questions/11146725/isinstance-and-mocking", +) @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_delay_extraction_httpdate(polling_response, http_response): polling = polling_response(http_response, {"Retry-After": "Mon, 20 Nov 1995 19:12:08 -0500"}) - from datetime import datetime as basedatetime - now_mock_datetime = datetime.datetime(1995, 11, 20, 18, 12, 8, tzinfo=_FixedOffset(-5*60)) - with mock.patch('datetime.datetime') as mock_datetime: + + now_mock_datetime = datetime.datetime(1995, 11, 20, 18, 12, 8, tzinfo=_FixedOffset(-5 * 60)) + with mock.patch("datetime.datetime") as mock_datetime: mock_datetime.now.return_value = now_mock_datetime mock_datetime.side_effect = lambda *args, **kw: basedatetime(*args, **kw) - assert polling._extract_delay() == 60*60 # one hour in seconds + assert polling._extract_delay() == 60 * 60 # one hour in seconds assert str(mock_datetime.now.call_args[0][0]) == "" + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_post(pipeline_client_builder, deserialization_cb, http_request, http_response): - # Test POST LRO with both Location and Operation-Location + # Test POST LRO with both Location and Operation-Location + + # The initial response contains both Location and Operation-Location, a 202 and no Body + initial_response = TestBasePolling.mock_send( + http_request, + http_response, + "POST", + 202, + { + "location": "http://example.org/location", + "operation-location": "http://example.org/async_monitor", + }, + "", + ) - # The initial response contains both Location and Operation-Location, a 202 and no Body - initial_response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', - 202, - { - 'location': 'http://example.org/location', - 'operation-location': 'http://example.org/async_monitor', - }, - '' - ) + def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/location": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"location_result": True} + ).http_response + elif request.url == "http://example.org/async_monitor": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"status": "Succeeded"} + ).http_response + else: + pytest.fail("No other query allowed") - def send(request, **kwargs): - assert request.method == 'GET' + client = pipeline_client_builder(send) - if request.url == 'http://example.org/location': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'location_result': True} - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") + # LRO options with Location final state + poll = LROPoller(client, initial_response, deserialization_cb, LROBasePolling(0)) + result = poll.result() + assert result["location_result"] == True - client = pipeline_client_builder(send) + # Location has no body - # LRO options with Location final state - poll = LROPoller( - client, - initial_response, - deserialization_cb, - LROBasePolling(0)) - result = poll.result() - assert result['location_result'] == True + def send(request, **kwargs): + assert request.method == "GET" - # Location has no body + if request.url == "http://example.org/location": + response = TestBasePolling.mock_send(http_request, http_response, "GET", 200, body=None).http_response + return response + elif request.url == "http://example.org/async_monitor": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"status": "Succeeded"} + ).http_response + else: + pytest.fail("No other query allowed") - def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - response = TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body=None - ).http_response - return response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") + client = pipeline_client_builder(send) - client = pipeline_client_builder(send) + poll = LROPoller(client, initial_response, deserialization_cb, LROBasePolling(0)) + result = poll.result() + assert result is None - poll = LROPoller( - client, - initial_response, - deserialization_cb, - LROBasePolling(0)) - result = poll.result() - assert result is None @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_post_resource_location(pipeline_client_builder, deserialization_cb, http_request, http_response): - # ResourceLocation - - # The initial response contains both Location and Operation-Location, a 202 and no Body - initial_response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', - 202, - { - 'operation-location': 'http://example.org/async_monitor', - }, - '' - ) - - def send(request, **kwargs): - assert request.method == 'GET' + # ResourceLocation + + # The initial response contains both Location and Operation-Location, a 202 and no Body + initial_response = TestBasePolling.mock_send( + http_request, + http_response, + "POST", + 202, + { + "operation-location": "http://example.org/async_monitor", + }, + "", + ) - if request.url == 'http://example.org/resource_location': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'location_result': True} - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} - ).http_response - else: - pytest.fail("No other query allowed") + def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/resource_location": + return TestBasePolling.mock_send( + http_request, http_response, "GET", 200, body={"location_result": True} + ).http_response + elif request.url == "http://example.org/async_monitor": + return TestBasePolling.mock_send( + http_request, + http_response, + "GET", + 200, + body={"status": "Succeeded", "resourceLocation": "http://example.org/resource_location"}, + ).http_response + else: + pytest.fail("No other query allowed") - client = pipeline_client_builder(send) + client = pipeline_client_builder(send) - poll = LROPoller( - client, - initial_response, - deserialization_cb, - LROBasePolling(0)) - result = poll.result() - assert result['location_result'] == True + poll = LROPoller(client, initial_response, deserialization_cb, LROBasePolling(0)) + result = poll.result() + assert result["location_result"] == True class TestBasePolling(object): - convert = re.compile('([a-z0-9])([A-Z])') + convert = re.compile("([a-z0-9])([A-Z])") @staticmethod def mock_send(http_request, http_response, method, status, headers=None, body=RESPONSE_BODY): @@ -333,13 +315,11 @@ def mock_send(http_request, http_response, method, status, headers=None, body=RE headers = {} response = Response() response._content_consumed = True - response._content = json.dumps(body).encode('ascii') if body is not None else None + response._content = json.dumps(body).encode("ascii") if body is not None else None response.request = Request() response.request.method = method response.request.url = RESOURCE_URL - response.request.headers = { - 'x-ms-client-request-id': '67f4dd4e-6262-45e1-8bed-5c45cf23b6d9' - } + response.request.headers = {"x-ms-client-request-id": "67f4dd4e-6262-45e1-8bed-5c45cf23b6d9"} response.status_code = status response.headers = headers response.headers.update({"content-type": "application/json; charset=utf8"}) @@ -360,25 +340,21 @@ def mock_send(http_request, http_response, method, status, headers=None, body=RE response.request.headers, body, None, # form_content - None # stream_content + None, # stream_content ) response = create_transport_response( http_response, request, response, ) - return PipelineResponse( - request, - response, - None # context - ) + return PipelineResponse(request, response, None) # context @staticmethod def mock_update(http_request, http_response, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) - response.request.method = 'GET' + response.request.method = "GET" response.headers = headers or {} response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" @@ -386,13 +362,13 @@ def mock_update(http_request, http_response, url, headers=None): if url == ASYNC_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = ASYNC_BODY.encode('ascii') + response._content = ASYNC_BODY.encode("ascii") response.randomFieldFromPollAsyncOpHeader = None elif url == LOCATION_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = LOCATION_BODY.encode('ascii') + response._content = LOCATION_BODY.encode("ascii") response.randomFieldFromPollLocationHeader = None elif url == ERROR: @@ -401,10 +377,10 @@ def mock_update(http_request, http_response, url, headers=None): elif url == RESOURCE_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = RESOURCE_BODY.encode('ascii') + response._content = RESOURCE_BODY.encode("ascii") else: - raise Exception('URL does not match') + raise Exception("URL does not match") request = http_request( response.request.method, response.request.url, @@ -414,11 +390,7 @@ def mock_update(http_request, http_response, url, headers=None): request, response, ) - return PipelineResponse( - request, - response, - None # context - ) + return PipelineResponse(request, response, None) # context @staticmethod def mock_outputs(pipeline_response): @@ -428,15 +400,13 @@ def mock_outputs(pipeline_response): except ValueError: raise DecodeError("Impossible to deserialize") - body = {TestBasePolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in body.items()} - properties = body.setdefault('properties', {}) - if 'name' in body: - properties['name'] = body['name'] + body = {TestBasePolling.convert.sub(r"\1_\2", k).lower(): v for k, v in body.items()} + properties = body.setdefault("properties", {}) + if "name" in body: + properties["name"] = body["name"] if properties: - properties = {TestBasePolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in properties.items()} - del body['properties'] + properties = {TestBasePolling.convert.sub(r"\1_\2", k).lower(): v for k, v in properties.items()} + del body["properties"] body.update(properties) resource = SimpleResource(**body) else: @@ -446,103 +416,67 @@ def mock_outputs(pipeline_response): @staticmethod def mock_deserialization_no_body(pipeline_response): - """Use this mock when you don't expect a return (last body irrelevant) - """ + """Use this mock when you don't expect a return (last body irrelevant)""" return None @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_long_running_put(self, http_request, http_response): - #TODO: Test custom header field + # TODO: Test custom header field # Test throw on non LRO related status code - response = TestBasePolling.mock_send( - http_request, - http_response, 'PUT', 1000, {}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 1000, {}) CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response with pytest.raises(HttpResponseError): - LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() # Test with no polling necessary - response_body = { - 'properties':{'provisioningState': 'Succeeded'}, - 'name': TEST_NAME - } - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {}, response_body - ) + response_body = {"properties": {"provisioningState": "Succeeded"}, "name": TEST_NAME} + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {}, response_body) + def no_update_allowed(url, headers=None): raise ValueError("Should not try to update") - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0) - ) + + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'operation-location': ASYNC_URL}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"operation-location": ASYNC_URL}) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'location': LOCATION_URL}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'location': LOCATION_URL}, response_body) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + http_request, http_response, "PUT", 201, {"location": LOCATION_URL}, response_body + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test fail to poll from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'operation-location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"operation-location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() # Test fail to poll from location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PUT', 201, - {'location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PUT", 201, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_long_running_patch(self, http_request, http_response): @@ -552,93 +486,84 @@ def test_long_running_patch(self, http_request, http_response): response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + "PATCH", + 202, + {"location": LOCATION_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 202, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + "PATCH", + 202, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from location header response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 200, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + "PATCH", + 200, + {"location": LOCATION_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, http_response, - 'PATCH', 200, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + "PATCH", + 200, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test fail to poll from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PATCH', 202, - {'operation-location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PATCH", 202, {"operation-location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() # Test fail to poll from location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'PATCH', 202, - {'location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "PATCH", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_long_running_delete(self, http_request, http_response): # Test polling from operation-location header response = TestBasePolling.mock_send( - http_request, - http_response, - 'DELETE', 202, - {'operation-location': ASYNC_URL}, - body="" + http_request, http_response, "DELETE", 202, {"operation-location": ASYNC_URL}, body="" ) CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_deserialization_no_body, - LROBasePolling(0)) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) poll.wait() - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader + is None + ) @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_long_running_post_legacy(self, http_request, http_response): @@ -648,64 +573,61 @@ def test_long_running_post_legacy(self, http_request, http_response): response = TestBasePolling.mock_send( http_request, http_response, - 'POST', 201, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", + 201, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_deserialization_no_body, - LROBasePolling(0)) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) poll.wait() - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader + is None + ) # Test polling from operation-location header response = TestBasePolling.mock_send( http_request, http_response, - 'POST', 202, - {'operation-location': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_deserialization_no_body, - LROBasePolling(0)) + "POST", + 202, + {"operation-location": ASYNC_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) poll.wait() - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader + is None + ) # Test polling from location header response = TestBasePolling.mock_send( http_request, http_response, - 'POST', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + "POST", + 202, + {"location": LOCATION_URL}, + body={"properties": {"provisioningState": "Succeeded"}}, + ) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test fail to poll from operation-location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'operation-location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"operation-location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() # Test fail to poll from location header - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': ERROR}) + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)).result() + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)).result() @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) def test_long_running_negative(self, http_request, http_response): @@ -714,48 +636,27 @@ def test_long_running_negative(self, http_request, http_response): CLIENT.http_request_type = http_request CLIENT.http_response_type = http_response # Test LRO PUT throws for invalid json - LOCATION_BODY = '{' - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': LOCATION_URL}) - poll = LROPoller( - CLIENT, - response, - TestBasePolling.mock_outputs, - LROBasePolling(0) - ) + LOCATION_BODY = "{" + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) with pytest.raises(DecodeError): poll.result() - LOCATION_BODY = '{\'"}' - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': LOCATION_URL}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) + LOCATION_BODY = "{'\"}" + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) with pytest.raises(DecodeError): poll.result() - LOCATION_BODY = '{' + LOCATION_BODY = "{" POLLING_STATUS = 203 - response = TestBasePolling.mock_send( - http_request, - http_response, - 'POST', 202, - {'location': LOCATION_URL}) - poll = LROPoller(CLIENT, response, - TestBasePolling.mock_outputs, - LROBasePolling(0)) - with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization + response = TestBasePolling.mock_send(http_request, http_response, "POST", 202, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0)) + with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization poll.result() - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode('ascii') + assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") - LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) + LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(REQUESTS_TRANSPORT_RESPONSES)) @@ -767,33 +668,25 @@ def test_post_final_state_via(self, pipeline_client_builder, deserialization_cb, initial_response = TestBasePolling.mock_send( http_request, http_response, - 'POST', + "POST", 202, { - 'location': 'http://example.org/location', - 'operation-location': 'http://example.org/async_monitor', + "location": "http://example.org/location", + "operation-location": "http://example.org/async_monitor", }, - '' + "", ) def send(request, **kwargs): - assert request.method == 'GET' + assert request.method == "GET" - if request.url == 'http://example.org/location': + if request.url == "http://example.org/location": return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'location_result': True} + http_request, http_response, "GET", 200, body={"location_result": True} ).http_response - elif request.url == 'http://example.org/async_monitor': + elif request.url == "http://example.org/async_monitor": return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} + http_request, http_response, "GET", 200, body={"status": "Succeeded"} ).http_response else: pytest.fail("No other query allowed") @@ -802,51 +695,36 @@ def send(request, **kwargs): # Test 1, LRO options with Location final state poll = LROPoller( - client, - initial_response, - deserialization_cb, - LROBasePolling(0, lro_options={"final-state-via": "location"})) + client, initial_response, deserialization_cb, LROBasePolling(0, lro_options={"final-state-via": "location"}) + ) result = poll.result() - assert result['location_result'] == True + assert result["location_result"] == True # Test 2, LRO options with Operation-Location final state poll = LROPoller( client, initial_response, deserialization_cb, - LROBasePolling(0, lro_options={"final-state-via": "operation-location"})) + LROBasePolling(0, lro_options={"final-state-via": "operation-location"}), + ) result = poll.result() - assert result['status'] == 'Succeeded' + assert result["status"] == "Succeeded" # Test 3, "do the right thing" and use Location by default - poll = LROPoller( - client, - initial_response, - deserialization_cb, - LROBasePolling(0)) + poll = LROPoller(client, initial_response, deserialization_cb, LROBasePolling(0)) result = poll.result() - assert result['location_result'] == True + assert result["location_result"] == True # Test 4, location has no body def send(request, **kwargs): - assert request.method == 'GET' + assert request.method == "GET" - if request.url == 'http://example.org/location': - return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body=None - ).http_response - elif request.url == 'http://example.org/async_monitor': + if request.url == "http://example.org/location": + return TestBasePolling.mock_send(http_request, http_response, "GET", 200, body=None).http_response + elif request.url == "http://example.org/async_monitor": return TestBasePolling.mock_send( - http_request, - http_response, - 'GET', - 200, - body={'status': 'Succeeded'} + http_request, http_response, "GET", 200, body={"status": "Succeeded"} ).http_response else: pytest.fail("No other query allowed") @@ -854,13 +732,12 @@ def send(request, **kwargs): client = pipeline_client_builder(send) poll = LROPoller( - client, - initial_response, - deserialization_cb, - LROBasePolling(0, lro_options={"final-state-via": "location"})) + client, initial_response, deserialization_cb, LROBasePolling(0, lro_options={"final-state-via": "location"}) + ) result = poll.result() assert result is None + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_final_get_via_location(port, http_request, deserialization_cb): client = TestRestClient(port) @@ -879,6 +756,7 @@ def test_final_get_via_location(port, http_request, deserialization_cb): result = poller.result() assert result == {"returnedFrom": "locationHeaderUrl"} + # THIS TEST WILL BE REMOVED SOON """Weird test, but we are temporarily adding back the POST check in OperationResourcePolling get_final_get_url. With the test added back, we should not exit on final state via checks and @@ -886,6 +764,8 @@ def test_final_get_via_location(port, http_request, deserialization_cb): and since I don't want to bother with adding a pipeline response object, just check that we get past the final state via checks """ + + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_post_check_patch(http_request): algorithm = OperationResourcePolling(lro_options={"final-state-via": "azure-async-operation"}) diff --git a/sdk/core/azure-core/tests/test_basic_transport.py b/sdk/core/azure-core/tests/test_basic_transport.py index 46a157fbce14..1a316a78a13b 100644 --- a/sdk/core/azure-core/tests/test_basic_transport.py +++ b/sdk/core/azure-core/tests/test_basic_transport.py @@ -19,7 +19,12 @@ from azure.core.exceptions import HttpResponseError import logging import pytest -from utils import HTTP_REQUESTS, request_and_responses_product, HTTP_CLIENT_TRANSPORT_RESPONSES, create_transport_response +from utils import ( + HTTP_REQUESTS, + request_and_responses_product, + HTTP_CLIENT_TRANSPORT_RESPONSES, + create_transport_response, +) from azure.core.rest._http_response_impl import HttpResponseImpl as RestHttpResponseImpl from azure.core.pipeline._tools import is_rest @@ -33,6 +38,7 @@ def __init__(self, request, body, content_type): def body(self): return self._body + class RestMockResponse(RestHttpResponseImpl): def __init__(self, request, body, content_type): super(RestMockResponse, self).__init__( @@ -56,6 +62,7 @@ def body(self): def content(self): return self._body + MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] @@ -66,9 +73,9 @@ def test_http_request_serialization(http_request): serialized = request.serialize() expected = ( - b'DELETE /container0/blob0 HTTP/1.1\r\n' + b"DELETE /container0/blob0 HTTP/1.1\r\n" # No headers - b'\r\n' + b"\r\n" ) assert serialized == expected @@ -77,24 +84,25 @@ def test_http_request_serialization(http_request): "DELETE", "/container0/blob0", # Use OrderedDict to get consistent test result on 3.5 where order is not guaranteed - headers=OrderedDict({ - "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", - "Authorization": "SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=", # fake key suppressed in credscan - "Content-Length": "0", - }) + headers=OrderedDict( + { + "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", + "Authorization": "SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=", # fake key suppressed in credscan + "Content-Length": "0", + } + ), ) serialized = request.serialize() expected = ( - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'Authorization: SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=\r\n' # fake key suppressed in credscan - b'Content-Length: 0\r\n' - b'\r\n' + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"Authorization: SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=\r\n" # fake key suppressed in credscan + b"Content-Length: 0\r\n" + b"\r\n" ) assert serialized == expected - # Method + Url + Headers + Body request = http_request( "DELETE", @@ -107,21 +115,21 @@ def test_http_request_serialization(http_request): serialized = request.serialize() expected = ( - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'Content-Length: 10\r\n' - b'\r\n' - b'I am groot' + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + b"I am groot" ) assert serialized == expected @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_url_join(http_request): - assert _urljoin('devstoreaccount1', '') == 'devstoreaccount1/' - assert _urljoin('devstoreaccount1', 'testdir/') == 'devstoreaccount1/testdir/' - assert _urljoin('devstoreaccount1/', '') == 'devstoreaccount1/' - assert _urljoin('devstoreaccount1/', 'testdir/') == 'devstoreaccount1/testdir/' + assert _urljoin("devstoreaccount1", "") == "devstoreaccount1/" + assert _urljoin("devstoreaccount1", "testdir/") == "devstoreaccount1/testdir/" + assert _urljoin("devstoreaccount1/", "") == "devstoreaccount1/" + assert _urljoin("devstoreaccount1/", "testdir/") == "devstoreaccount1/testdir/" @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_CLIENT_TRANSPORT_RESPONSES)) @@ -154,19 +162,16 @@ def test_response_deserialization(http_request): # Method + Url request = http_request("DELETE", "/container0/blob0") body = ( - b'HTTP/1.1 202 Accepted\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' + b"HTTP/1.1 202 Accepted\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" ) response = _deserialize_response(body, request) assert response.status_code == 202 assert response.reason == "Accepted" - assert response.headers == { - 'x-ms-request-id': '778fdc83-801e-0000-62ff-0334671e284f', - 'x-ms-version': '2018-11-09' - } + assert response.headers == {"x-ms-request-id": "778fdc83-801e-0000-62ff-0334671e284f", "x-ms-version": "2018-11-09"} # Method + Url + Headers + Body request = http_request( @@ -178,41 +183,39 @@ def test_response_deserialization(http_request): ) request.set_bytes_body(b"I am groot") body = ( - b'HTTP/1.1 200 OK\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'I am groot' + b"HTTP/1.1 200 OK\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"I am groot" ) response = _deserialize_response(body, request) assert isinstance(response.status_code, int) assert response.reason == "OK" - assert response.headers == { - 'x-ms-request-id': '778fdc83-801e-0000-62ff-0334671e284f', - 'x-ms-version': '2018-11-09' - } + assert response.headers == {"x-ms-request-id": "778fdc83-801e-0000-62ff-0334671e284f", "x-ms-version": "2018-11-09"} assert response.text() == "I am groot" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_response_deserialization_utf8_bom(http_request): request = http_request("DELETE", "/container0/blob0") body = ( - b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' - b'x-ms-error-code: InvalidInput\r\n' - b'x-ms-request-id: 5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\r\n' - b'x-ms-version: 2019-02-02\r\n' - b'Content-Length: 220\r\n' - b'Content-Type: application/xml\r\n' - b'Server: Windows-Azure-Blob/1.0\r\n' - b'\r\n' + b"HTTP/1.1 400 One of the request inputs is not valid.\r\n" + b"x-ms-error-code: InvalidInput\r\n" + b"x-ms-request-id: 5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\r\n" + b"x-ms-version: 2019-02-02\r\n" + b"Content-Length: 220\r\n" + b"Content-Type: application/xml\r\n" + b"Server: Windows-Azure-Blob/1.0\r\n" + b"\r\n" b'\xef\xbb\xbf\nInvalidInputOne' - b'of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z' + b"of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z" ) response = _deserialize_response(body, request) - assert response.body().startswith(b'\xef\xbb\xbf') + assert response.body().startswith(b"\xef\xbb\xbf") @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -220,9 +223,7 @@ def test_multipart_send(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) req0 = http_request("DELETE", "/container0/blob0") req1 = http_request("DELETE", "/container1/blob1") @@ -232,32 +233,32 @@ def test_multipart_send(http_request): req0, req1, policies=[header_policy], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" # Fix it so test are deterministic + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", # Fix it so test are deterministic ) with Pipeline(transport) as pipeline: pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -265,9 +266,7 @@ def test_multipart_send(http_request): def test_multipart_send_with_context(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) req0 = http_request("DELETE", "/container0/blob0") req1 = http_request("DELETE", "/container1/blob1") @@ -277,35 +276,35 @@ def test_multipart_send_with_context(http_request): req0, req1, policies=[header_policy], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", # Fix it so test are deterministic - headers={'Accept': 'application/json'} + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", # Fix it so test are deterministic + headers={"Accept": "application/json"}, ) with Pipeline(transport) as pipeline: pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'Accept: application/json\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'Accept: application/json\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"Accept: application/json\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"Accept: application/json\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -314,20 +313,13 @@ def test_multipart_send_with_one_changeset(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) - requests = [ - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") - ] + requests = [http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1")] changeset = http_request("", "") changeset.set_multipart_mixed( - *requests, - policies=[header_policy], - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + *requests, policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") @@ -340,30 +332,30 @@ def test_multipart_send_with_one_changeset(http_request): pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -372,23 +364,21 @@ def test_multipart_send_with_multiple_changesets(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) changeset1 = http_request("", "") changeset1.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1"), policies=[header_policy], - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) changeset2 = http_request("", "") changeset2.set_multipart_mixed( http_request("DELETE", "/container2/blob2"), http_request("DELETE", "/container3/blob3"), policies=[header_policy], - boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314" + boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") @@ -403,53 +393,53 @@ def test_multipart_send_with_multiple_changesets(http_request): pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 3\r\n' - b'\r\n' - b'DELETE /container3/blob3 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 3\r\n" + b"\r\n" + b"DELETE /container3/blob3 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -458,62 +448,60 @@ def test_multipart_send_with_combination_changeset_first(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) changeset = http_request("", "") changeset.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1"), policies=[header_policy], - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, http_request("DELETE", "/container2/blob2"), policies=[header_policy], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) with Pipeline(transport) as pipeline: pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -522,62 +510,60 @@ def test_multipart_send_with_combination_changeset_last(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) changeset = http_request("", "") changeset.set_multipart_mixed( http_request("DELETE", "/container1/blob1"), http_request("DELETE", "/container2/blob2"), policies=[header_policy], - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( http_request("DELETE", "/container0/blob0"), changeset, policies=[header_policy], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) with Pipeline(transport) as pipeline: pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @@ -586,15 +572,13 @@ def test_multipart_send_with_combination_changeset_middle(http_request): transport = mock.MagicMock(spec=HttpTransport) - header_policy = HeadersPolicy({ - 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' - }) + header_policy = HeadersPolicy({"x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT"}) changeset = http_request("", "") changeset.set_multipart_mixed( http_request("DELETE", "/container1/blob1"), policies=[header_policy], - boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( @@ -602,66 +586,61 @@ def test_multipart_send_with_combination_changeset_middle(http_request): changeset, http_request("DELETE", "/container2/blob2"), policies=[header_policy], - boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", ) with Pipeline(transport) as pipeline: pipeline.run(request) assert request.body == ( - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'DELETE /container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'DELETE /container1/blob1 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'DELETE /container2/blob2 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'\r\n' - b'\r\n' - b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"DELETE /container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: multipart/mixed; boundary=changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"DELETE /container1/blob1 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"DELETE /container2/blob2 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"\r\n" + b"\r\n" + b"--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" ) @pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) def test_multipart_receive(http_request, mock_response): - class ResponsePolicy(object): def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None - response.http_response.headers['x-ms-fun'] = 'true' + response.http_response.headers["x-ms-fun"] = "true" req0 = http_request("DELETE", "/container0/blob0") req1 = http_request("DELETE", "/container1/blob1") request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed( - req0, - req1, - policies=[ResponsePolicy()] - ) + request.set_multipart_mixed(req0, req1, policies=[ResponsePolicy()]) body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" @@ -692,8 +671,8 @@ def on_response(self, request, response): response = mock_response( request, - body_as_str.encode('ascii'), - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + body_as_str.encode("ascii"), + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed", ) response = response.parts() @@ -702,11 +681,12 @@ def on_response(self, request, response): res0 = response[0] assert res0.status_code == 202 - assert res0.headers['x-ms-fun'] == 'true' + assert res0.headers["x-ms-fun"] == "true" res1 = response[1] assert res1.status_code == 404 - assert res1.headers['x-ms-fun'] == 'true' + assert res1.headers["x-ms-fun"] == "true" + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_raise_for_status_bad_response(mock_response): @@ -715,6 +695,7 @@ def test_raise_for_status_bad_response(mock_response): with pytest.raises(HttpResponseError): response.raise_for_status() + @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_raise_for_status_good_response(mock_response): response = mock_response(request=None, body=None, content_type=None) @@ -727,46 +708,43 @@ def test_multipart_receive_with_one_changeset(http_request, mock_response): changeset = http_request(None, None) changeset.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 202 Accepted\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 202 Accepted\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 202 Accepted\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 202 Accepted\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -783,75 +761,71 @@ def test_multipart_receive_with_multiple_changesets(http_request, mock_response) changeset1 = http_request(None, None) changeset1.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1") ) changeset2 = http_request(None, None) changeset2.set_multipart_mixed( - http_request("DELETE", "/container2/blob2"), - http_request("DELETE", "/container3/blob3") + http_request("DELETE", "/container2/blob2"), http_request("DELETE", "/container3/blob3") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset1, changeset2) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314"\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 3\r\n' - b'\r\n' - b'HTTP/1.1 409\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 3\r\n" + b"\r\n" + b"HTTP/1.1 409\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -869,55 +843,52 @@ def test_multipart_receive_with_combination_changeset_first(http_request, mock_r changeset = http_request(None, None) changeset.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - http_request("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), http_request("DELETE", "/container1/blob1") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset, http_request("DELETE", "/container2/blob2")) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -937,53 +908,49 @@ def test_multipart_receive_with_combination_changeset_middle(http_request, mock_ request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - http_request("DELETE", "/container0/blob0"), - changeset, - http_request("DELETE", "/container2/blob2") + http_request("DELETE", "/container0/blob0"), changeset, http_request("DELETE", "/container2/blob2") ) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -1000,56 +967,53 @@ def test_multipart_receive_with_combination_changeset_last(http_request, mock_re changeset = http_request(None, None) changeset.set_multipart_mixed( - http_request("DELETE", "/container1/blob1"), - http_request("DELETE", "/container2/blob2") + http_request("DELETE", "/container1/blob1"), http_request("DELETE", "/container2/blob2") ) request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(http_request("DELETE", "/container0/blob0"), changeset) body_as_bytes = ( - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 2\r\n' - b'\r\n' - b'HTTP/1.1 200\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 2\r\n" + b"\r\n" + b"HTTP/1.1 200\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 0\r\n' - b'\r\n' - b'HTTP/1.1 202\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' - b'Content-Type: application/http\r\n' - b'Content-Transfer-Encoding: binary\r\n' - b'Content-ID: 1\r\n' - b'\r\n' - b'HTTP/1.1 404\r\n' - b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' - b'x-ms-version: 2018-11-09\r\n' - b'\r\n' - b'\r\n' - b'--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' - b'\r\n' - b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n' + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 0\r\n" + b"\r\n" + b"HTTP/1.1 202\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n" + b"Content-Type: application/http\r\n" + b"Content-Transfer-Encoding: binary\r\n" + b"Content-ID: 1\r\n" + b"\r\n" + b"HTTP/1.1 404\r\n" + b"x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + b"x-ms-version: 2018-11-09\r\n" + b"\r\n" + b"\r\n" + b"--changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n" + b"\r\n" + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--\r\n" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) parts = [] @@ -1073,21 +1037,19 @@ def test_multipart_receive_with_bom(http_request, mock_response): b"Content-Type: application/http\n" b"Content-Transfer-Encoding: binary\n" b"Content-ID: 0\n" - b'\r\n' - b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' - b'Content-Length: 220\r\n' - b'Content-Type: application/xml\r\n' - b'Server: Windows-Azure-Blob/1.0\r\n' - b'\r\n' + b"\r\n" + b"HTTP/1.1 400 One of the request inputs is not valid.\r\n" + b"Content-Length: 220\r\n" + b"Content-Type: application/xml\r\n" + b"Server: Windows-Azure-Blob/1.0\r\n" + b"\r\n" b'\xef\xbb\xbf\nInvalidInputOne' - b'of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z\n' + b"of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z\n" b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) response = mock_response( - request, - body_as_bytes, - "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + request, body_as_bytes, "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" ) response = response.parts() @@ -1095,7 +1057,7 @@ def test_multipart_receive_with_bom(http_request, mock_response): res0 = response[0] assert res0.status_code == 400 - assert res0.body().startswith(b'\xef\xbb\xbf') + assert res0.body().startswith(b"\xef\xbb\xbf") @pytest.mark.parametrize("http_request,mock_response", request_and_responses_product(MOCK_RESPONSES)) @@ -1132,8 +1094,8 @@ def test_recursive_multipart_receive(http_request, mock_response): response = mock_response( request, - body_as_str.encode('ascii'), - "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" + body_as_str.encode("ascii"), + "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6", ) response = response.parts() @@ -1188,4 +1150,4 @@ def test_conflict_timeout(caplog, port, http_request): with pytest.raises(ValueError): with Pipeline(transport) as pipeline: - pipeline.run(request, connection_timeout=(100, 100), read_timeout = 100) + pipeline.run(request, connection_timeout=(100, 100), read_timeout=100) diff --git a/sdk/core/azure-core/tests/test_connection_string_parsing.py b/sdk/core/azure-core/tests/test_connection_string_parsing.py index 1a3957e5b35d..713d973e8fd2 100644 --- a/sdk/core/azure-core/tests/test_connection_string_parsing.py +++ b/sdk/core/azure-core/tests/test_connection_string_parsing.py @@ -3,14 +3,15 @@ from devtools_testutils import AzureMgmtTestCase + class CoreConnectionStringParserTests(AzureMgmtTestCase): # cSpell:disable def test_parsing_with_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs): - conn_str = 'Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, True) - assert parse_result["Endpoint"] == 'XXXXENDPOINTXXXX' - assert parse_result["SharedAccessKeyName"] == 'XXXXPOLICYXXXX' - assert parse_result["SharedAccessKey"] == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result["Endpoint"] == "XXXXENDPOINTXXXX" + assert parse_result["SharedAccessKeyName"] == "XXXXPOLICYXXXX" + assert parse_result["SharedAccessKey"] == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" with pytest.raises(KeyError): parse_result["endPoint"] with pytest.raises(KeyError): @@ -19,37 +20,37 @@ def test_parsing_with_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs) parse_result["sharedaccesskey"] def test_parsing_with_case_insensitive_keys_for_sensitive_conn_str(self, **kwargs): - conn_str = 'Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) - assert parse_result["endpoint"] == 'XXXXENDPOINTXXXX' - assert parse_result["sharedaccesskeyname"] == 'XXXXPOLICYXXXX' - assert parse_result["sharedaccesskey"] == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result["endpoint"] == "XXXXENDPOINTXXXX" + assert parse_result["sharedaccesskeyname"] == "XXXXPOLICYXXXX" + assert parse_result["sharedaccesskey"] == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" def test_parsing_with_case_insensitive_keys_for_insensitive_conn_str(self, **kwargs): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) - assert parse_result["endpoint"] == 'XXXXENDPOINTXXXX' - assert parse_result["sharedaccesskeyname"] == 'XXXXPOLICYXXXX' - assert parse_result["sharedaccesskey"] == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result["endpoint"] == "XXXXENDPOINTXXXX" + assert parse_result["sharedaccesskeyname"] == "XXXXPOLICYXXXX" + assert parse_result["sharedaccesskey"] == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" def test_error_with_duplicate_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs): - conn_str = 'Endpoint=XXXXENDPOINTXXXX;Endpoint=XXXXENDPOINT2XXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "Endpoint=XXXXENDPOINTXXXX;Endpoint=XXXXENDPOINT2XXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" with pytest.raises(ValueError) as e: parse_result = parse_connection_string(conn_str, True) assert str(e.value) == "Connection string is either blank or malformed." def test_success_with_duplicate_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs): - conn_str = 'enDpoInt=XXXXENDPOINTXXXX;Endpoint=XXXXENDPOINT2XXXX;' + conn_str = "enDpoInt=XXXXENDPOINTXXXX;Endpoint=XXXXENDPOINT2XXXX;" parse_result = parse_connection_string(conn_str, True) - assert parse_result["enDpoInt"] == 'XXXXENDPOINTXXXX' - assert parse_result["Endpoint"] == 'XXXXENDPOINT2XXXX' + assert parse_result["enDpoInt"] == "XXXXENDPOINTXXXX" + assert parse_result["Endpoint"] == "XXXXENDPOINT2XXXX" def test_error_with_duplicate_case_insensitive_keys_for_insensitive_conn_str(self, **kwargs): - conn_str = 'endPoinT=XXXXENDPOINTXXXX;eNdpOint=XXXXENDPOINT2XXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "endPoinT=XXXXENDPOINTXXXX;eNdpOint=XXXXENDPOINT2XXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" with pytest.raises(ValueError) as e: parse_result = parse_connection_string(conn_str, False) assert str(e.value) == "Duplicate key in connection string: endpoint" - + def test_error_with_malformed_conn_str(self): for conn_str in ["", "foobar", "foo;bar;baz", ";", "foo=;bar=;", "=", "=;=="]: with pytest.raises(ValueError) as e: @@ -57,34 +58,34 @@ def test_error_with_malformed_conn_str(self): self.assertEqual(str(e.value), "Connection string is either blank or malformed.") def test_case_insensitive_clear_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) parse_result.clear() assert len(parse_result) == 0 def test_case_insensitive_copy_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) copied = parse_result.copy() assert copied == parse_result - + def test_case_insensitive_get_method(self): - conn_str = 'Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) - assert parse_result.get("sharedaccesskeyname") == 'XXXXPOLICYXXXX' - assert parse_result.get("sharedaccesskey") == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result.get("sharedaccesskeyname") == "XXXXPOLICYXXXX" + assert parse_result.get("sharedaccesskey") == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" assert parse_result.get("accesskey") is None assert parse_result.get("accesskey", "XXothertestkeyXX=") == "XXothertestkeyXX=" def test_case_insensitive_keys_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) keys = parse_result.keys() assert len(keys) == 3 assert "endpoint" in keys - + def test_case_insensitive_pop_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) endpoint = parse_result.pop("endpoint") sharedaccesskey = parse_result.pop("sharedaccesskey") @@ -93,8 +94,8 @@ def test_case_insensitive_pop_method(self): assert sharedaccesskey == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" def test_case_insensitive_update_with_insensitive_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' - conn_str2 = 'hostName=XXXXENDPOINTXXXX;ACCessKEy=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" + conn_str2 = "hostName=XXXXENDPOINTXXXX;ACCessKEy=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;" parse_result_insensitive = parse_connection_string(conn_str, False) parse_result_insensitive2 = parse_connection_string(conn_str2, False) @@ -108,12 +109,12 @@ def test_case_insensitive_update_with_insensitive_method(self): parse_result_insensitive_dupe = parse_connection_string(conn_str_duplicate_key, False) parse_result_insensitive.update(parse_result_insensitive_dupe) assert parse_result_insensitive_dupe["endpoint"] == "XXXXENDPOINT2XXXX" - assert parse_result_insensitive_dupe["accesskey"] == "TestKey" + assert parse_result_insensitive_dupe["accesskey"] == "TestKey" assert len(parse_result_insensitive) == 5 def test_case_sensitive_update_with_insensitive_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' - conn_str2 = 'hostName=XXXXENDPOINTXXXX;ACCessKEy=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" + conn_str2 = "hostName=XXXXENDPOINTXXXX;ACCessKEy=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;" parse_result_insensitive = parse_connection_string(conn_str, False) parse_result_sensitive = parse_connection_string(conn_str2, True) @@ -124,8 +125,9 @@ def test_case_sensitive_update_with_insensitive_method(self): parse_result_sensitive["hostname"] def test_case_insensitive_values_method(self): - conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str = "enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" parse_result = parse_connection_string(conn_str, False) values = parse_result.values() assert len(values) == 3 - # cSpell:enable \ No newline at end of file + + # cSpell:enable diff --git a/sdk/core/azure-core/tests/test_custom_hook_policy.py b/sdk/core/azure-core/tests/test_custom_hook_policy.py index 8ff38e0c203f..402582d148a4 100644 --- a/sdk/core/azure-core/tests/test_custom_hook_policy.py +++ b/sdk/core/azure-core/tests/test_custom_hook_policy.py @@ -9,10 +9,11 @@ import mock from azure.core import PipelineClient from azure.core.pipeline.policies import CustomHookPolicy, UserAgentPolicy -from azure.core.pipeline.transport import HttpTransport +from azure.core.pipeline.transport import HttpTransport import pytest from utils import HTTP_REQUESTS + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_response_hook_policy_in_init(http_request): def test_callback(response): @@ -21,15 +22,13 @@ def test_callback(response): transport = mock.MagicMock(spec=HttpTransport) url = "http://localhost" custom_hook_policy = CustomHookPolicy(raw_response_hook=test_callback) - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = PipelineClient(base_url=url, policies=policies, transport=transport) request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_response_hook_policy_in_request(http_request): def test_callback(response): @@ -38,15 +37,13 @@ def test_callback(response): transport = mock.MagicMock(spec=HttpTransport) url = "http://localhost" custom_hook_policy = CustomHookPolicy() - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = PipelineClient(base_url=url, policies=policies, transport=transport) request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request, raw_response_hook=test_callback) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_response_hook_policy_in_both(http_request): def test_callback(response): @@ -58,15 +55,13 @@ def test_callback_request(response): transport = mock.MagicMock(spec=HttpTransport) url = "http://localhost" custom_hook_policy = CustomHookPolicy(raw_response_hook=test_callback) - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = PipelineClient(base_url=url, policies=policies, transport=transport) request = http_request("GET", url) with pytest.raises(TypeError): client._pipeline.run(request, raw_response_hook=test_callback_request) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_hook_policy_in_init(http_request): def test_callback(response): @@ -75,15 +70,13 @@ def test_callback(response): transport = mock.MagicMock(spec=HttpTransport) url = "http://localhost" custom_hook_policy = CustomHookPolicy(raw_request_hook=test_callback) - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = PipelineClient(base_url=url, policies=policies, transport=transport) request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_hook_policy_in_request(http_request): def test_callback(response): @@ -92,15 +85,13 @@ def test_callback(response): transport = mock.MagicMock(spec=HttpTransport) url = "http://localhost" custom_hook_policy = CustomHookPolicy() - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = PipelineClient(base_url=url, policies=policies, transport=transport) request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request, raw_request_hook=test_callback) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_hook_policy_in_both(http_request): def test_callback(response): @@ -112,10 +103,7 @@ def test_callback_request(response): transport = mock.MagicMock(spec=HttpTransport) url = "http://localhost" custom_hook_policy = CustomHookPolicy(raw_request_hook=test_callback) - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = PipelineClient(base_url=url, policies=policies, transport=transport) request = http_request("GET", url) with pytest.raises(TypeError): diff --git a/sdk/core/azure-core/tests/test_enums.py b/sdk/core/azure-core/tests/test_enums.py index 144444b8f51f..2c42d67b6dee 100644 --- a/sdk/core/azure-core/tests/test_enums.py +++ b/sdk/core/azure-core/tests/test_enums.py @@ -27,17 +27,18 @@ from azure.core import CaseInsensitiveEnumMeta + class MyCustomEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta): - FOO = 'foo' - BAR = 'bar' + FOO = "foo" + BAR = "bar" def test_case_insensitive_enums(): - assert MyCustomEnum.foo.value == 'foo' - assert MyCustomEnum.FOO.value == 'foo' - assert MyCustomEnum('bar').value == 'bar' - assert 'bar' == MyCustomEnum.BAR - assert 'bar' == MyCustomEnum.bar - assert MyCustomEnum['foo'] == 'foo' - assert MyCustomEnum['FOO'] == 'foo' + assert MyCustomEnum.foo.value == "foo" + assert MyCustomEnum.FOO.value == "foo" + assert MyCustomEnum("bar").value == "bar" + assert "bar" == MyCustomEnum.BAR + assert "bar" == MyCustomEnum.bar + assert MyCustomEnum["foo"] == "foo" + assert MyCustomEnum["FOO"] == "foo" assert isinstance(MyCustomEnum.BAR, str) diff --git a/sdk/core/azure-core/tests/test_error_map.py b/sdk/core/azure-core/tests/test_error_map.py index 267ab1a5ee62..51422493d71e 100644 --- a/sdk/core/azure-core/tests/test_error_map.py +++ b/sdk/core/azure-core/tests/test_error_map.py @@ -32,36 +32,34 @@ ) from utils import request_and_responses_product, create_http_response, HTTP_RESPONSES + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_error_map(http_request, http_response): request = http_request("GET", "") response = create_http_response(http_response, request, None) - error_map = { - 404: ResourceNotFoundError - } + error_map = {404: ResourceNotFoundError} with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_error_map_no_default(http_request, http_response): request = http_request("GET", "") response = create_http_response(http_response, request, None) - error_map = ErrorMap({ - 404: ResourceNotFoundError - }) + error_map = ErrorMap({404: ResourceNotFoundError}) with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_error_map_with_default(http_request, http_response): request = http_request("GET", "") response = create_http_response(http_response, request, None) - error_map = ErrorMap({ - 404: ResourceNotFoundError - }, default_error=ResourceExistsError) + error_map = ErrorMap({404: ResourceNotFoundError}, default_error=ResourceExistsError) with pytest.raises(ResourceExistsError): map_error(401, response, error_map) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_only_default(http_request, http_response): request = http_request("GET", "") diff --git a/sdk/core/azure-core/tests/test_exceptions.py b/sdk/core/azure-core/tests/test_exceptions.py index cb28b1bea121..71db780dae6d 100644 --- a/sdk/core/azure-core/tests/test_exceptions.py +++ b/sdk/core/azure-core/tests/test_exceptions.py @@ -29,17 +29,24 @@ from unittest.mock import Mock # module under test -from azure.core.exceptions import HttpResponseError, ODataV4Error, ODataV4Format, SerializationError, DeserializationError +from azure.core.exceptions import ( + HttpResponseError, + ODataV4Error, + ODataV4Format, + SerializationError, + DeserializationError, +) from azure.core.pipeline.transport import RequestsTransportResponse from azure.core.pipeline.transport._base import _HttpResponseBase as PipelineTransportHttpResponseBase from azure.core.rest._http_response_impl import _HttpResponseBaseImpl as RestHttpResponseBase from utils import HTTP_REQUESTS + class PipelineTransportMockResponse(PipelineTransportHttpResponseBase): def __init__(self, json_body): super(PipelineTransportMockResponse, self).__init__( request=None, - internal_response = None, + internal_response=None, ) self.status_code = 400 self.reason = "Bad Request" @@ -49,6 +56,7 @@ def __init__(self, json_body): def body(self): return self._body + class RestMockResponse(RestHttpResponseBase): def __init__(self, json_body): super(RestMockResponse, self).__init__( @@ -69,30 +77,28 @@ def body(self): def content(self): return self._body + MOCK_RESPONSES = [PipelineTransportMockResponse, RestMockResponse] -class FakeErrorOne(object): +class FakeErrorOne(object): def __init__(self): self.error = Mock(message="A fake error", code="FakeErrorOne") class FakeErrorTwo(object): - def __init__(self): self.code = "FakeErrorTwo" self.message = "A different fake error" class FakeHttpResponse(HttpResponseError): - def __init__(self, response, error, *args, **kwargs): self.error = error super(FakeHttpResponse, self).__init__(self, response=response, *args, **kwargs) class TestExceptions(object): - def test_empty_httpresponse_error(self): error = HttpResponseError() assert str(error) == "Operation returned an invalid status 'None'" @@ -112,14 +118,14 @@ def test_message_httpresponse_error(self): assert error.status_code is None def test_error_continuation_token(self): - error = HttpResponseError(message="Specific error message", continuation_token='foo') + error = HttpResponseError(message="Specific error message", continuation_token="foo") assert str(error) == "Specific error message" assert error.message == "Specific error message" assert error.response is None assert error.reason is None assert error.error is None assert error.status_code is None - assert error.continuation_token == 'foo' + assert error.continuation_token == "foo" @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_deserialized_httpresponse_error_code(self, mock_response): @@ -151,7 +157,6 @@ def test_deserialized_httpresponse_error_code(self, mock_response): assert str(error) == "(FakeErrorOne) A fake error\nCode: FakeErrorOne\nMessage: A fake error" - @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_deserialized_httpresponse_error_message(self, mock_response): """This is backward compat support for weird responses, adn even if it's likely @@ -183,7 +188,7 @@ def test_httpresponse_error_with_response(self, port, mock_response): error = HttpResponseError(response=http_response) assert error.message == "Operation returned an invalid status 'OK'" assert error.response is not None - assert error.reason == 'OK' + assert error.reason == "OK" assert isinstance(error.status_code, int) assert error.error is None @@ -194,15 +199,14 @@ def test_odata_v4_exception(self, mock_response): "code": "501", "message": "Unsupported functionality", "target": "query", - "details": [{ - "code": "301", - "target": "$search", - "message": "$search query option not supported", - }], - "innererror": { - "trace": [], - "context": {} - } + "details": [ + { + "code": "301", + "target": "$search", + "message": "$search query option not supported", + } + ], + "innererror": {"trace": [], "context": {}}, } } exp = ODataV4Error(mock_response(json.dumps(message).encode("utf-8"))) @@ -225,8 +229,7 @@ def test_odata_v4_exception(self, mock_response): @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_odata_v4_minimal(self, mock_response): - """Minimal valid OData v4 is code/message and nothing else. - """ + """Minimal valid OData v4 is code/message and nothing else.""" message = { "error": { "code": "501", @@ -242,17 +245,14 @@ def test_odata_v4_minimal(self, mock_response): @pytest.mark.parametrize("mock_response", MOCK_RESPONSES) def test_broken_odata_details(self, mock_response): - """Do not block creating a nice exception if "details" only is broken - """ + """Do not block creating a nice exception if "details" only is broken""" message = { "error": { "code": "Conflict", "message": "The maximum number of Free ServerFarms allowed in a Subscription is 10.", "target": None, "details": [ - { - "message": "The maximum number of Free ServerFarms allowed in a Subscription is 10." - }, + {"message": "The maximum number of Free ServerFarms allowed in a Subscription is 10."}, {"code": "Conflict"}, { "errorentity": { @@ -291,7 +291,10 @@ def test_non_odatav4_error_body(self, client, http_request): response = client.send_request(request) with pytest.raises(HttpResponseError) as ex: response.raise_for_status() - assert str(ex.value) == "Operation returned an invalid status 'BAD REQUEST'\nContent: {\"code\": 400, \"error\": {\"global\": [\"MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API\"]}}" + assert ( + str(ex.value) + == 'Operation returned an invalid status \'BAD REQUEST\'\nContent: {"code": 400, "error": {"global": ["MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API"]}}' + ) @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_malformed_json(self, client, http_request): @@ -299,7 +302,10 @@ def test_malformed_json(self, client, http_request): response = client.send_request(request) with pytest.raises(HttpResponseError) as ex: response.raise_for_status() - assert str(ex.value) == "Operation returned an invalid status 'BAD REQUEST'\nContent: {\"code\": 400, \"error\": {\"global\": [\"MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API\"]" + assert ( + str(ex.value) + == 'Operation returned an invalid status \'BAD REQUEST\'\nContent: {"code": 400, "error": {"global": ["MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API"]' + ) @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_text(self, client, http_request): @@ -315,7 +321,8 @@ def test_datav4_error(self, client, http_request): response = client.send_request(request) with pytest.raises(HttpResponseError) as ex: response.raise_for_status() - assert "Content: {\"" not in str(ex.value) + assert 'Content: {"' not in str(ex.value) + def test_serialization_error(): message = "Oopsy bad input passed for serialization" @@ -328,6 +335,7 @@ def test_serialization_error(): raise error assert str(ex.value) == message + def test_deserialization_error(): message = "Oopsy bad input passed for serialization" error = DeserializationError(message) diff --git a/sdk/core/azure-core/tests/test_http_logging_policy.py b/sdk/core/azure-core/tests/test_http_logging_policy.py index d882bcdecfe8..dfc397fa1da6 100644 --- a/sdk/core/azure-core/tests/test_http_logging_policy.py +++ b/sdk/core/azure-core/tests/test_http_logging_policy.py @@ -6,32 +6,32 @@ import pytest import logging import types + try: from unittest.mock import Mock except ImportError: # python < 3.3 from mock import Mock # type: ignore -from azure.core.pipeline import ( - PipelineResponse, - PipelineRequest, - PipelineContext -) +from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) from utils import HTTP_RESPONSES, create_http_response, request_and_responses_product from azure.core.pipeline._tools import is_rest + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -40,7 +40,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 http_response.headers["x-ms-error-code"] = "ERRORCODE" @@ -52,16 +52,16 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'No body was attached to the request' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "No body was attached to the request" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" assert messages_response[2] == " 'x-ms-error-code': 'ERRORCODE'" mock_handler.reset() @@ -76,7 +76,7 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 4 messages_request1 = mock_handler.messages[0].message.split("\n") messages_response1 = mock_handler.messages[1].message.split("\n") @@ -84,22 +84,22 @@ def emit(self, record): messages_response2 = mock_handler.messages[3].message.split("\n") assert messages_request1[0] == "Request URL: 'http://localhost/'" assert messages_request1[1] == "Request method: 'GET'" - assert messages_request1[2] == 'Request headers:' - assert messages_request1[3] == 'No body was attached to the request' - assert messages_response1[0] == 'Response status: 202' - assert messages_response1[1] == 'Response headers:' + assert messages_request1[2] == "Request headers:" + assert messages_request1[3] == "No body was attached to the request" + assert messages_response1[0] == "Response status: 202" + assert messages_response1[1] == "Response headers:" assert messages_request2[0] == "Request URL: 'http://localhost/'" assert messages_request2[1] == "Request method: 'GET'" - assert messages_request2[2] == 'Request headers:' - assert messages_request2[3] == 'No body was attached to the request' - assert messages_response2[0] == 'Response status: 202' - assert messages_response2[1] == 'Response headers:' + assert messages_request2[2] == "Request headers:" + assert messages_request2[3] == "No body was attached to the request" + assert messages_response2[0] == "Response status: 202" + assert messages_response2[1] == "Response headers:" mock_handler.reset() # Headers and query parameters - policy.allowed_query_params = ['country'] + policy.allowed_query_params = ["country"] universal_request.headers = { "Accept": "Caramel", @@ -115,7 +115,7 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") @@ -123,40 +123,31 @@ def emit(self, record): assert messages_request[1] == "Request method: 'GET'" assert messages_request[2] == "Request headers:" # Dict not ordered in Python, exact logging order doesn't matter - assert set([ - messages_request[3], - messages_request[4] - ]) == set([ - " 'Accept': 'Caramel'", - " 'Hate': 'REDACTED'" - ]) - assert messages_request[5] == 'No body was attached to the request' + assert set([messages_request[3], messages_request[4]]) == set([" 'Accept': 'Caramel'", " 'Hate': 'REDACTED'"]) + assert messages_request[5] == "No body was attached to the request" assert messages_response[0] == "Response status: 202" assert messages_response[1] == "Response headers:" # Dict not ordered in Python, exact logging order doesn't matter - assert set([ - messages_response[2], - messages_response[3] - ]) == set([ - " 'Content-Type': 'Caramel'", - " 'HateToo': 'REDACTED'" - ]) + assert set([messages_response[2], messages_response[3]]) == set( + [" 'Content-Type': 'Caramel'", " 'HateToo': 'REDACTED'"] + ) mock_handler.reset() - @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger_operation_level(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -164,9 +155,9 @@ def emit(self, record): logger.setLevel(logging.DEBUG) policy = HttpLoggingPolicy() - kwargs={'logger': logger} + kwargs = {"logger": logger} - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -177,16 +168,16 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'No body was attached to the request' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "No body was attached to the request" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() @@ -202,7 +193,7 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 4 messages_request1 = mock_handler.messages[0].message.split("\n") messages_response1 = mock_handler.messages[1].message.split("\n") @@ -210,31 +201,33 @@ def emit(self, record): messages_response2 = mock_handler.messages[3].message.split("\n") assert messages_request1[0] == "Request URL: 'http://localhost/'" assert messages_request1[1] == "Request method: 'GET'" - assert messages_request1[2] == 'Request headers:' - assert messages_request1[3] == 'No body was attached to the request' - assert messages_response1[0] == 'Response status: 202' - assert messages_response1[1] == 'Response headers:' + assert messages_request1[2] == "Request headers:" + assert messages_request1[3] == "No body was attached to the request" + assert messages_response1[0] == "Response status: 202" + assert messages_response1[1] == "Response headers:" assert messages_request2[0] == "Request URL: 'http://localhost/'" assert messages_request2[1] == "Request method: 'GET'" - assert messages_request2[2] == 'Request headers:' - assert messages_request2[3] == 'No body was attached to the request' - assert messages_response2[0] == 'Response status: 202' - assert messages_response2[1] == 'Response headers:' + assert messages_request2[2] == "Request headers:" + assert messages_request2[3] == "No body was attached to the request" + assert messages_response2[0] == "Response status: 202" + assert messages_response2[1] == "Response headers:" mock_handler.reset() @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger_with_body(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -243,7 +236,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") universal_request.body = "testbody" http_response = create_http_response(http_response, universal_request, None) http_response.status_code = 202 @@ -253,31 +246,33 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'A body is sent with the request' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "A body is sent with the request" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_http_logger_with_generator_body(http_request, http_response): - class MockHandler(logging.Handler): def __init__(self): super(MockHandler, self).__init__() self.messages = [] + def reset(self): self.messages = [] + def emit(self, record): self.messages.append(record) + mock_handler = MockHandler() logger = logging.getLogger("testlogger") @@ -286,7 +281,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") mock = Mock() mock.__class__ = types.GeneratorType universal_request.body = mock @@ -298,15 +293,15 @@ def emit(self, record): response = PipelineResponse(request, http_response, request.context) policy.on_response(request, response) - assert all(m.levelname == 'INFO' for m in mock_handler.messages) + assert all(m.levelname == "INFO" for m in mock_handler.messages) assert len(mock_handler.messages) == 2 messages_request = mock_handler.messages[0].message.split("\n") messages_response = mock_handler.messages[1].message.split("\n") assert messages_request[0] == "Request URL: 'http://localhost/'" assert messages_request[1] == "Request method: 'GET'" - assert messages_request[2] == 'Request headers:' - assert messages_request[3] == 'File upload' - assert messages_response[0] == 'Response status: 202' - assert messages_response[1] == 'Response headers:' + assert messages_request[2] == "Request headers:" + assert messages_request[3] == "File upload" + assert messages_response[0] == "Response status: 202" + assert messages_response[1] == "Response headers:" mock_handler.reset() diff --git a/sdk/core/azure-core/tests/test_messaging_cloud_event.py b/sdk/core/azure-core/tests/test_messaging_cloud_event.py index 23931f17c7a4..2148c6271908 100644 --- a/sdk/core/azure-core/tests/test_messaging_cloud_event.py +++ b/sdk/core/azure-core/tests/test_messaging_cloud_event.py @@ -11,6 +11,7 @@ from azure.core.utils._messaging_shared import _get_json_content from azure.core.serialization import NULL + class MockQueueMessage(object): def __init__(self, content=None): self.id = uuid.uuid4() @@ -21,31 +22,33 @@ def __init__(self, content=None): self.pop_receipt = None self.next_visible_on = None + class MockServiceBusReceivedMessage(object): def __init__(self, body=None, **kwargs): - self.body=body - self.application_properties=None - self.session_id=None - self.message_id='3f6c5441-5be5-4f33-80c3-3ffeb6a090ce' - self.content_type='application/cloudevents+json; charset=utf-8' - self.correlation_id=None - self.to=None - self.reply_to=None - self.reply_to_session_id=None - self.subject=None - self.time_to_live=datetime.timedelta(days=14) - self.partition_key=None - self.scheduled_enqueue_time_utc=None - self.auto_renew_error=None, - self.dead_letter_error_description=None - self.dead_letter_reason=None - self.dead_letter_source=None - self.delivery_count=13 - self.enqueued_sequence_number=0 - self.enqueued_time_utc=datetime.datetime(2021, 7, 22, 22, 27, 41, 236000) - self.expires_at_utc=datetime.datetime(2021, 8, 5, 22, 27, 41, 236000) - self.sequence_number=11219 - self.lock_token='233146e3-d5a6-45eb-826f-691d82fb8b13' + self.body = body + self.application_properties = None + self.session_id = None + self.message_id = "3f6c5441-5be5-4f33-80c3-3ffeb6a090ce" + self.content_type = "application/cloudevents+json; charset=utf-8" + self.correlation_id = None + self.to = None + self.reply_to = None + self.reply_to_session_id = None + self.subject = None + self.time_to_live = datetime.timedelta(days=14) + self.partition_key = None + self.scheduled_enqueue_time_utc = None + self.auto_renew_error = (None,) + self.dead_letter_error_description = None + self.dead_letter_reason = None + self.dead_letter_source = None + self.delivery_count = 13 + self.enqueued_sequence_number = 0 + self.enqueued_time_utc = datetime.datetime(2021, 7, 22, 22, 27, 41, 236000) + self.expires_at_utc = datetime.datetime(2021, 8, 5, 22, 27, 41, 236000) + self.sequence_number = 11219 + self.lock_token = "233146e3-d5a6-45eb-826f-691d82fb8b13" + class MockEventhubData(object): def __init__(self, body=None): @@ -55,7 +58,7 @@ def __init__(self, body=None): raise ValueError("EventData cannot be None.") # Internal usage only for transforming AmqpAnnotatedMessage to outgoing EventData - self.body=body + self.body = body self._raw_amqp_message = "some amqp data" self.message_id = None self.content_type = None @@ -68,21 +71,22 @@ def __init__(self, data=None): def __iter__(self): return self - + def __next__(self): if not self.data: return """{"id":"f208feff-099b-4bda-a341-4afd0fa02fef","source":"https://egsample.dev/sampleevent","data":"ServiceBus","type":"Azure.Sdk.Sample","time":"2021-07-22T22:27:38.960209Z","specversion":"1.0"}""" return self.data - + next = __next__ + class MockEhBody(object): def __init__(self, data=None): self.data = data def __iter__(self): return self - + def __next__(self): if not self.data: return b'[{"id":"f208feff-099b-4bda-a341-4afd0fa02fef","source":"https://egsample.dev/sampleevent","data":"Eventhub","type":"Azure.Sdk.Sample","time":"2021-07-22T22:27:38.960209Z","specversion":"1.0"}]' @@ -93,106 +97,96 @@ def __next__(self): # Cloud Event tests def test_cloud_event_constructor(): - event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - data='cloudevent' - ) - - assert event.specversion == '1.0' + event = CloudEvent(source="Azure.Core.Sample", type="SampleType", data="cloudevent") + + assert event.specversion == "1.0" assert event.time.__class__ == datetime.datetime assert event.id is not None - assert event.source == 'Azure.Core.Sample' - assert event.data == 'cloudevent' + assert event.source == "Azure.Core.Sample" + assert event.data == "cloudevent" + def test_cloud_event_constructor_unexpected_keyword(): with pytest.raises(ValueError) as e: event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - data='cloudevent', + source="Azure.Core.Sample", + type="SampleType", + data="cloudevent", unexpected_keyword="not allowed", - another_bad_kwarg="not allowed either" - ) + another_bad_kwarg="not allowed either", + ) assert "unexpected_keyword" in e assert "another_bad_kwarg" in e + def test_cloud_event_constructor_blank_data(): - event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - data='' - ) - - assert event.specversion == '1.0' + event = CloudEvent(source="Azure.Core.Sample", type="SampleType", data="") + + assert event.specversion == "1.0" assert event.time.__class__ == datetime.datetime assert event.id is not None - assert event.source == 'Azure.Core.Sample' - assert event.data == '' + assert event.source == "Azure.Core.Sample" + assert event.data == "" + def test_cloud_event_constructor_NULL_data(): - event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - data=NULL - ) + event = CloudEvent(source="Azure.Core.Sample", type="SampleType", data=NULL) assert event.data == NULL assert event.data is NULL + def test_cloud_event_constructor_none_data(): - event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - data=None - ) + event = CloudEvent(source="Azure.Core.Sample", type="SampleType", data=None) assert event.data == None + def test_cloud_event_constructor_missing_data(): event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - ) - + source="Azure.Core.Sample", + type="SampleType", + ) + assert event.data == None assert event.datacontenttype == None assert event.dataschema == None assert event.subject == None + def test_cloud_storage_dict(): cloud_storage_dict = { - "id":"a0517898-9fa4-4e70-b4a3-afda1dd68672", - "source":"/subscriptions/{subscription-id}/resourceGroups/{resource-group}/providers/Microsoft.Storage/storageAccounts/{storage-account}", - "data":{ - "api":"PutBlockList", - "client_request_id":"6d79dbfb-0e37-4fc4-981f-442c9ca65760", - "request_id":"831e1650-001e-001b-66ab-eeb76e000000", - "e_tag":"0x8D4BCC2E4835CD0", - "content_type":"application/octet-stream", - "content_length":524288, - "blob_type":"BlockBlob", - "url":"https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", - "sequencer":"00000000000004420000000000028963", - "storage_diagnostics":{"batchId":"b68529f3-68cd-4744-baa4-3c0498ec19f0"} + "id": "a0517898-9fa4-4e70-b4a3-afda1dd68672", + "source": "/subscriptions/{subscription-id}/resourceGroups/{resource-group}/providers/Microsoft.Storage/storageAccounts/{storage-account}", + "data": { + "api": "PutBlockList", + "client_request_id": "6d79dbfb-0e37-4fc4-981f-442c9ca65760", + "request_id": "831e1650-001e-001b-66ab-eeb76e000000", + "e_tag": "0x8D4BCC2E4835CD0", + "content_type": "application/octet-stream", + "content_length": 524288, + "blob_type": "BlockBlob", + "url": "https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", + "sequencer": "00000000000004420000000000028963", + "storage_diagnostics": {"batchId": "b68529f3-68cd-4744-baa4-3c0498ec19f0"}, }, - "type":"Microsoft.Storage.BlobCreated", - "time":"2021-02-18T20:18:10.581147898Z", - "specversion":"1.0" + "type": "Microsoft.Storage.BlobCreated", + "time": "2021-02-18T20:18:10.581147898Z", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_storage_dict) assert event.data == { - "api":"PutBlockList", - "client_request_id":"6d79dbfb-0e37-4fc4-981f-442c9ca65760", - "request_id":"831e1650-001e-001b-66ab-eeb76e000000", - "e_tag":"0x8D4BCC2E4835CD0", - "content_type":"application/octet-stream", - "content_length":524288, - "blob_type":"BlockBlob", - "url":"https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", - "sequencer":"00000000000004420000000000028963", - "storage_diagnostics":{"batchId":"b68529f3-68cd-4744-baa4-3c0498ec19f0"} + "api": "PutBlockList", + "client_request_id": "6d79dbfb-0e37-4fc4-981f-442c9ca65760", + "request_id": "831e1650-001e-001b-66ab-eeb76e000000", + "e_tag": "0x8D4BCC2E4835CD0", + "content_type": "application/octet-stream", + "content_length": 524288, + "blob_type": "BlockBlob", + "url": "https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", + "sequencer": "00000000000004420000000000028963", + "storage_diagnostics": {"batchId": "b68529f3-68cd-4744-baa4-3c0498ec19f0"}, } assert event.specversion == "1.0" assert event.time.__class__ == datetime.datetime @@ -207,14 +201,14 @@ def test_cloud_storage_dict(): def test_cloud_custom_dict_with_extensions(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.539861122+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.539861122+00:00", + "specversion": "1.0", "ext1": "example", - "ext2": "example2" + "ext2": "example2", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -225,14 +219,15 @@ def test_cloud_custom_dict_with_extensions(): assert event.time.microsecond == 539861 assert event.extensions == {"ext1": "example", "ext2": "example2"} + def test_cloud_custom_dict_ms_precision_is_gt_six(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.539861122+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.539861122+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -242,14 +237,15 @@ def test_cloud_custom_dict_ms_precision_is_gt_six(): assert event.time.hour == 20 assert event.time.microsecond == 539861 + def test_cloud_custom_dict_ms_precision_is_lt_six(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.123+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.123+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -259,14 +255,15 @@ def test_cloud_custom_dict_ms_precision_is_lt_six(): assert event.time.hour == 20 assert event.time.microsecond == 123000 + def test_cloud_custom_dict_ms_precision_is_eq_six(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.123456+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.123456+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -276,14 +273,15 @@ def test_cloud_custom_dict_ms_precision_is_eq_six(): assert event.time.hour == 20 assert event.time.microsecond == 123456 + def test_cloud_custom_dict_ms_precision_is_gt_six_z_not(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.539861122Z", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.539861122Z", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -293,14 +291,15 @@ def test_cloud_custom_dict_ms_precision_is_gt_six_z_not(): assert event.time.hour == 20 assert event.time.microsecond == 539861 + def test_cloud_custom_dict_ms_precision_is_lt_six_z_not(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.123Z", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.123Z", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -310,14 +309,15 @@ def test_cloud_custom_dict_ms_precision_is_lt_six_z_not(): assert event.time.hour == 20 assert event.time.microsecond == 123000 + def test_cloud_custom_dict_ms_precision_is_eq_six_z_not(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e034", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10.123456Z", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e034", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10.123456Z", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) assert event.data == {"team": "event grid squad"} @@ -327,85 +327,91 @@ def test_cloud_custom_dict_ms_precision_is_eq_six_z_not(): assert event.time.hour == 20 assert event.time.microsecond == 123456 + def test_cloud_custom_dict_blank_data(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":'', - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": "", + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) - assert event.data == '' + assert event.data == "" assert event.__class__ == CloudEvent + def test_cloud_custom_dict_no_data(): cloud_custom_dict_with_missing_data = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_missing_data) assert event.__class__ == CloudEvent assert event.data is None + def test_cloud_custom_dict_null_data(): cloud_custom_dict_with_none_data = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "type":"Azure.Sdk.Sample", - "data":None, - "dataschema":None, - "time":"2021-02-18T20:18:10+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "type": "Azure.Sdk.Sample", + "data": None, + "dataschema": None, + "time": "2021-02-18T20:18:10+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_none_data) assert event.__class__ == CloudEvent assert event.data == NULL assert event.dataschema is NULL + def test_cloud_custom_dict_valid_optional_attrs(): cloud_custom_dict_with_none_data = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "type":"Azure.Sdk.Sample", - "data":None, - "dataschema":"exists", - "time":"2021-02-18T20:18:10+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "type": "Azure.Sdk.Sample", + "data": None, + "dataschema": "exists", + "time": "2021-02-18T20:18:10+00:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_with_none_data) assert event.__class__ == CloudEvent assert event.data is NULL assert event.dataschema == "exists" + def test_cloud_custom_dict_both_data_and_base64(): cloud_custom_dict_with_data_and_base64 = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":"abc", - "data_base64":"Y2Wa==", - "type":"Azure.Sdk.Sample", - "time":"2021-02-18T20:18:10+00:00", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": "abc", + "data_base64": "Y2Wa==", + "type": "Azure.Sdk.Sample", + "time": "2021-02-18T20:18:10+00:00", + "specversion": "1.0", } with pytest.raises(ValueError): event = CloudEvent.from_dict(cloud_custom_dict_with_data_and_base64) + def test_cloud_custom_dict_base64(): cloud_custom_dict_base64 = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data_base64":'Y2xvdWRldmVudA==', # cspell:disable-line - "type":"Azure.Sdk.Sample", - "time":"2021-02-23T17:11:13.308772-08:00", - "specversion":"1.0" + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data_base64": "Y2xvdWRldmVudA==", # cspell:disable-line + "type": "Azure.Sdk.Sample", + "time": "2021-02-23T17:11:13.308772-08:00", + "specversion": "1.0", } event = CloudEvent.from_dict(cloud_custom_dict_base64) - assert event.data == b'cloudevent' + assert event.data == b"cloudevent" assert event.specversion == "1.0" assert event.time.hour == 17 assert event.time.minute == 11 @@ -413,59 +419,55 @@ def test_cloud_custom_dict_base64(): assert event.time.tzinfo is not None assert event.__class__ == CloudEvent + def test_data_and_base64_both_exist_raises(): with pytest.raises(ValueError): - CloudEvent.from_dict( - {"source":'sample', - "type":'type', - "data":'data', - "data_base64":'Y2kQ==' - } - ) + CloudEvent.from_dict({"source": "sample", "type": "type", "data": "data", "data_base64": "Y2kQ=="}) + def test_cloud_event_repr(): - event = CloudEvent( - source='Azure.Core.Sample', - type='SampleType', - data='cloudevent' - ) + event = CloudEvent(source="Azure.Core.Sample", type="SampleType", data="cloudevent") assert repr(event).startswith("CloudEvent(source=Azure.Core.Sample, type=SampleType, specversion=1.0,") -def test_extensions_upper_case_value_error(): - with pytest.raises(ValueError): - event = CloudEvent( - source='sample', - type='type', - data='data', - extensions={"lowercase123": "accepted", "NOTlower123": "not allowed"} + +def test_extensions_upper_case_value_error(): + with pytest.raises(ValueError): + event = CloudEvent( + source="sample", + type="type", + data="data", + extensions={"lowercase123": "accepted", "NOTlower123": "not allowed"}, ) -def test_extensions_not_alphanumeric_value_error(): - with pytest.raises(ValueError): - event = CloudEvent( - source='sample', - type='type', - data='data', - extensions={"lowercase123": "accepted", "not@lph@nu^^3ic": "not allowed"} + +def test_extensions_not_alphanumeric_value_error(): + with pytest.raises(ValueError): + event = CloudEvent( + source="sample", + type="type", + data="data", + extensions={"lowercase123": "accepted", "not@lph@nu^^3ic": "not allowed"}, ) + def test_cloud_from_dict_with_invalid_extensions(): cloud_custom_dict_with_extensions = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2020-08-07T02:06:08.11969Z", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2020-08-07T02:06:08.11969Z", + "specversion": "1.0", "ext1": "example", - "BADext2": "example2" + "BADext2": "example2", } with pytest.raises(ValueError): event = CloudEvent.from_dict(cloud_custom_dict_with_extensions) + def test_cloud_custom_dict_ms_precision_is_gt_six(): - time ="2021-02-18T20:18:10.539861122+00:00" + time = "2021-02-18T20:18:10.539861122+00:00" date_obj = _convert_to_isoformat(time) assert date_obj.month == 2 @@ -473,8 +475,9 @@ def test_cloud_custom_dict_ms_precision_is_gt_six(): assert date_obj.hour == 20 assert date_obj.microsecond == 539861 + def test_cloud_custom_dict_ms_precision_is_lt_six(): - time ="2021-02-18T20:18:10.123+00:00" + time = "2021-02-18T20:18:10.123+00:00" date_obj = _convert_to_isoformat(time) assert date_obj.month == 2 @@ -482,8 +485,9 @@ def test_cloud_custom_dict_ms_precision_is_lt_six(): assert date_obj.hour == 20 assert date_obj.microsecond == 123000 + def test_cloud_custom_dict_ms_precision_is_eq_six(): - time ="2021-02-18T20:18:10.123456+00:00" + time = "2021-02-18T20:18:10.123456+00:00" date_obj = _convert_to_isoformat(time) assert date_obj.month == 2 @@ -491,8 +495,9 @@ def test_cloud_custom_dict_ms_precision_is_eq_six(): assert date_obj.hour == 20 assert date_obj.microsecond == 123456 + def test_cloud_custom_dict_ms_precision_is_gt_six_z_not(): - time ="2021-02-18T20:18:10.539861122Z" + time = "2021-02-18T20:18:10.539861122Z" date_obj = _convert_to_isoformat(time) assert date_obj.month == 2 @@ -500,8 +505,9 @@ def test_cloud_custom_dict_ms_precision_is_gt_six_z_not(): assert date_obj.hour == 20 assert date_obj.microsecond == 539861 + def test_cloud_custom_dict_ms_precision_is_lt_six_z_not(): - time ="2021-02-18T20:18:10.123Z" + time = "2021-02-18T20:18:10.123Z" date_obj = _convert_to_isoformat(time) assert date_obj.month == 2 @@ -509,8 +515,9 @@ def test_cloud_custom_dict_ms_precision_is_lt_six_z_not(): assert date_obj.hour == 20 assert date_obj.microsecond == 123000 + def test_cloud_custom_dict_ms_precision_is_eq_six_z_not(): - time ="2021-02-18T20:18:10.123456Z" + time = "2021-02-18T20:18:10.123456Z" date_obj = _convert_to_isoformat(time) assert date_obj.month == 2 @@ -518,40 +525,53 @@ def test_cloud_custom_dict_ms_precision_is_eq_six_z_not(): assert date_obj.hour == 20 assert date_obj.microsecond == 123456 + def test_eventgrid_event_schema_raises(): cloud_custom_dict = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "data":{"team": "event grid squad"}, + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "data": {"team": "event grid squad"}, "dataVersion": "1.0", - "subject":"Azure.Sdk.Sample", - "eventTime":"2020-08-07T02:06:08.11969Z", - "eventType":"pull request", + "subject": "Azure.Sdk.Sample", + "eventTime": "2020-08-07T02:06:08.11969Z", + "eventType": "pull request", } - with pytest.raises(ValueError, match="The event you are trying to parse follows the Eventgrid Schema. You can parse EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library."): + with pytest.raises( + ValueError, + match="The event you are trying to parse follows the Eventgrid Schema. You can parse EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library.", + ): CloudEvent.from_dict(cloud_custom_dict) + def test_wrong_schema_raises_no_source(): cloud_custom_dict = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2020-08-07T02:06:08.11969Z", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2020-08-07T02:06:08.11969Z", + "specversion": "1.0", } - with pytest.raises(ValueError, match="The event does not conform to the cloud event spec https://github.com/cloudevents/spec. The `source` and `type` params are required."): + with pytest.raises( + ValueError, + match="The event does not conform to the cloud event spec https://github.com/cloudevents/spec. The `source` and `type` params are required.", + ): CloudEvent.from_dict(cloud_custom_dict) + def test_wrong_schema_raises_no_type(): cloud_custom_dict = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "data":{"team": "event grid squad"}, - "source":"Azure/Sdk/Sample", - "time":"2020-08-07T02:06:08.11969Z", - "specversion":"1.0", + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "data": {"team": "event grid squad"}, + "source": "Azure/Sdk/Sample", + "time": "2020-08-07T02:06:08.11969Z", + "specversion": "1.0", } - with pytest.raises(ValueError, match="The event does not conform to the cloud event spec https://github.com/cloudevents/spec. The `source` and `type` params are required."): + with pytest.raises( + ValueError, + match="The event does not conform to the cloud event spec https://github.com/cloudevents/spec. The `source` and `type` params are required.", + ): CloudEvent.from_dict(cloud_custom_dict) + def test_get_bytes_storage_queue(): cloud_storage_dict = """{ "id":"a0517898-9fa4-4e70-b4a3-afda1dd68672", @@ -575,116 +595,119 @@ def test_get_bytes_storage_queue(): obj = MockQueueMessage(content=cloud_storage_dict) dict = _get_json_content(obj) - assert dict.get('data') == { - "api":"PutBlockList", - "client_request_id":"6d79dbfb-0e37-4fc4-981f-442c9ca65760", - "request_id":"831e1650-001e-001b-66ab-eeb76e000000", - "e_tag":"0x8D4BCC2E4835CD0", - "content_type":"application/octet-stream", - "content_length":524288, - "blob_type":"BlockBlob", - "url":"https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", - "sequencer":"00000000000004420000000000028963", - "storage_diagnostics":{"batchId":"b68529f3-68cd-4744-baa4-3c0498ec19f0"} - } - assert dict.get('specversion') == "1.0" + assert dict.get("data") == { + "api": "PutBlockList", + "client_request_id": "6d79dbfb-0e37-4fc4-981f-442c9ca65760", + "request_id": "831e1650-001e-001b-66ab-eeb76e000000", + "e_tag": "0x8D4BCC2E4835CD0", + "content_type": "application/octet-stream", + "content_length": 524288, + "blob_type": "BlockBlob", + "url": "https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", + "sequencer": "00000000000004420000000000028963", + "storage_diagnostics": {"batchId": "b68529f3-68cd-4744-baa4-3c0498ec19f0"}, + } + assert dict.get("specversion") == "1.0" + def test_get_bytes_storage_queue_wrong_content(): - cloud_storage_string = u'This is a random string which must fail' + cloud_storage_string = "This is a random string which must fail" obj = MockQueueMessage(content=cloud_storage_string) with pytest.raises(ValueError, match="Failed to load JSON content from the object."): _get_json_content(obj) + def test_get_bytes_servicebus(): obj = MockServiceBusReceivedMessage( body=MockBody(), - message_id='3f6c5441-5be5-4f33-80c3-3ffeb6a090ce', - content_type='application/cloudevents+json; charset=utf-8', + message_id="3f6c5441-5be5-4f33-80c3-3ffeb6a090ce", + content_type="application/cloudevents+json; charset=utf-8", time_to_live=datetime.timedelta(days=14), delivery_count=13, enqueued_sequence_number=0, enqueued_time_utc=datetime.datetime(2021, 7, 22, 22, 27, 41, 236000), expires_at_utc=datetime.datetime(2021, 8, 5, 22, 27, 41, 236000), sequence_number=11219, - lock_token='233146e3-d5a6-45eb-826f-691d82fb8b13' + lock_token="233146e3-d5a6-45eb-826f-691d82fb8b13", ) dict = _get_json_content(obj) - assert dict.get('data') == "ServiceBus" - assert dict.get('specversion') == '1.0' + assert dict.get("data") == "ServiceBus" + assert dict.get("specversion") == "1.0" + def test_get_bytes_servicebus_wrong_content(): obj = MockServiceBusReceivedMessage( body=MockBody(data="random string"), - message_id='3f6c5441-5be5-4f33-80c3-3ffeb6a090ce', - content_type='application/json; charset=utf-8', + message_id="3f6c5441-5be5-4f33-80c3-3ffeb6a090ce", + content_type="application/json; charset=utf-8", time_to_live=datetime.timedelta(days=14), delivery_count=13, enqueued_sequence_number=0, enqueued_time_utc=datetime.datetime(2021, 7, 22, 22, 27, 41, 236000), expires_at_utc=datetime.datetime(2021, 8, 5, 22, 27, 41, 236000), sequence_number=11219, - lock_token='233146e3-d5a6-45eb-826f-691d82fb8b13' + lock_token="233146e3-d5a6-45eb-826f-691d82fb8b13", ) with pytest.raises(ValueError, match="Failed to load JSON content from the object."): _get_json_content(obj) + def test_get_bytes_eventhubs(): - obj = MockEventhubData( - body=MockEhBody() - ) + obj = MockEventhubData(body=MockEhBody()) dict = _get_json_content(obj) - assert dict.get('data') == 'Eventhub' - assert dict.get('specversion') == '1.0' + assert dict.get("data") == "Eventhub" + assert dict.get("specversion") == "1.0" + def test_get_bytes_eventhubs_wrong_content(): - obj = MockEventhubData( - body=MockEhBody(data='random string') - ) + obj = MockEventhubData(body=MockEhBody(data="random string")) with pytest.raises(ValueError, match="Failed to load JSON content from the object."): dict = _get_json_content(obj) + def test_get_bytes_random_obj(): json_str = '{"id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", "source": "https://egtest.dev/cloudcustomevent", "data": {"team": "event grid squad"}, "type": "Azure.Sdk.Sample", "time": "2020-08-07T02:06:08.11969Z", "specversion": "1.0"}' - random_obj = { - "id":"de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", - "source":"https://egtest.dev/cloudcustomevent", - "data":{"team": "event grid squad"}, - "type":"Azure.Sdk.Sample", - "time":"2020-08-07T02:06:08.11969Z", - "specversion":"1.0" + random_obj = { + "id": "de0fd76c-4ef4-4dfb-ab3a-8f24a307e033", + "source": "https://egtest.dev/cloudcustomevent", + "data": {"team": "event grid squad"}, + "type": "Azure.Sdk.Sample", + "time": "2020-08-07T02:06:08.11969Z", + "specversion": "1.0", } assert _get_json_content(json_str) == random_obj + def test_from_json_sb(): obj = MockServiceBusReceivedMessage( body=MockBody(), - message_id='3f6c5441-5be5-4f33-80c3-3ffeb6a090ce', - content_type='application/cloudevents+json; charset=utf-8', + message_id="3f6c5441-5be5-4f33-80c3-3ffeb6a090ce", + content_type="application/cloudevents+json; charset=utf-8", time_to_live=datetime.timedelta(days=14), delivery_count=13, enqueued_sequence_number=0, enqueued_time_utc=datetime.datetime(2021, 7, 22, 22, 27, 41, 236000), expires_at_utc=datetime.datetime(2021, 8, 5, 22, 27, 41, 236000), sequence_number=11219, - lock_token='233146e3-d5a6-45eb-826f-691d82fb8b13' + lock_token="233146e3-d5a6-45eb-826f-691d82fb8b13", ) event = CloudEvent.from_json(obj) assert event.id == "f208feff-099b-4bda-a341-4afd0fa02fef" assert event.data == "ServiceBus" + def test_from_json_eh(): - obj = MockEventhubData( - body=MockEhBody() - ) + obj = MockEventhubData(body=MockEhBody()) event = CloudEvent.from_json(obj) assert event.id == "f208feff-099b-4bda-a341-4afd0fa02fef" assert event.data == "Eventhub" + def test_from_json_storage(): cloud_storage_dict = """{ "id":"a0517898-9fa4-4e70-b4a3-afda1dd68672", @@ -708,17 +731,17 @@ def test_from_json_storage(): obj = MockQueueMessage(content=cloud_storage_dict) event = CloudEvent.from_json(obj) assert event.data == { - "api":"PutBlockList", - "client_request_id":"6d79dbfb-0e37-4fc4-981f-442c9ca65760", - "request_id":"831e1650-001e-001b-66ab-eeb76e000000", - "e_tag":"0x8D4BCC2E4835CD0", - "content_type":"application/octet-stream", - "content_length":524288, - "blob_type":"BlockBlob", - "url":"https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", - "sequencer":"00000000000004420000000000028963", - "storage_diagnostics":{"batchId":"b68529f3-68cd-4744-baa4-3c0498ec19f0"} - } + "api": "PutBlockList", + "client_request_id": "6d79dbfb-0e37-4fc4-981f-442c9ca65760", + "request_id": "831e1650-001e-001b-66ab-eeb76e000000", + "e_tag": "0x8D4BCC2E4835CD0", + "content_type": "application/octet-stream", + "content_length": 524288, + "blob_type": "BlockBlob", + "url": "https://oc2d2817345i60006.blob.core.windows.net/oc2d2817345i200097container/oc2d2817345i20002296blob", + "sequencer": "00000000000004420000000000028963", + "storage_diagnostics": {"batchId": "b68529f3-68cd-4744-baa4-3c0498ec19f0"}, + } def test_from_json(): diff --git a/sdk/core/azure-core/tests/test_paging.py b/sdk/core/azure-core/tests/test_paging.py index decf3fae94f9..7057a75aee53 100644 --- a/sdk/core/azure-core/tests/test_paging.py +++ b/sdk/core/azure-core/tests/test_paging.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from azure.core.paging import ItemPaged from azure.core.exceptions import HttpResponseError @@ -31,101 +31,74 @@ class TestPaging(object): - def test_basic_paging(self): - def get_next(continuation_token=None): - """Simplify my life and return JSON and not response, but should be response. - """ + """Simplify my life and return JSON and not response, but should be response.""" if not continuation_token: - return { - 'nextLink': 'page2', - 'value': ['value1.0', 'value1.1'] - } + return {"nextLink": "page2", "value": ["value1.0", "value1.1"]} else: - return { - 'nextLink': None, - 'value': ['value2.0', 'value2.1'] - } + return {"nextLink": None, "value": ["value2.0", "value2.1"]} def extract_data(response): - return response['nextLink'], iter(response['value']) + return response["nextLink"], iter(response["value"]) pager = ItemPaged(get_next, extract_data) result_iterated = list(pager) - assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated + assert ["value1.0", "value1.1", "value2.0", "value2.1"] == result_iterated def test_by_page_paging(self): - def get_next(continuation_token=None): - """Simplify my life and return JSON and not response, but should be response. - """ + """Simplify my life and return JSON and not response, but should be response.""" if not continuation_token: - return { - 'nextLink': 'page2', - 'value': ['value1.0', 'value1.1'] - } + return {"nextLink": "page2", "value": ["value1.0", "value1.1"]} else: - return { - 'nextLink': None, - 'value': ['value2.0', 'value2.1'] - } + return {"nextLink": None, "value": ["value2.0", "value2.1"]} def extract_data(response): - return response['nextLink'], iter(response['value']) + return response["nextLink"], iter(response["value"]) pager = ItemPaged(get_next, extract_data).by_page() page1 = next(pager) - assert list(page1) == ['value1.0', 'value1.1'] + assert list(page1) == ["value1.0", "value1.1"] page2 = next(pager) - assert list(page2) == ['value2.0', 'value2.1'] + assert list(page2) == ["value2.0", "value2.1"] with pytest.raises(StopIteration): next(pager) def test_advance_paging(self): - def get_next(continuation_token=None): - """Simplify my life and return JSON and not response, but should be response. - """ + """Simplify my life and return JSON and not response, but should be response.""" if not continuation_token: - return { - 'nextLink': 'page2', - 'value': ['value1.0', 'value1.1'] - } + return {"nextLink": "page2", "value": ["value1.0", "value1.1"]} else: - return { - 'nextLink': None, - 'value': ['value2.0', 'value2.1'] - } + return {"nextLink": None, "value": ["value2.0", "value2.1"]} def extract_data(response): - return response['nextLink'], iter(response['value']) + return response["nextLink"], iter(response["value"]) pager = ItemPaged(get_next, extract_data) page1 = next(pager) - assert page1 == 'value1.0' + assert page1 == "value1.0" page1 = next(pager) - assert page1 == 'value1.1' + assert page1 == "value1.1" page2 = next(pager) - assert page2 == 'value2.0' + assert page2 == "value2.0" page2 = next(pager) - assert page2 == 'value2.1' + assert page2 == "value2.1" with pytest.raises(StopIteration): next(pager) def test_none_value(self): def get_next(continuation_token=None): - return { - 'nextLink': None, - 'value': None - } + return {"nextLink": None, "value": None} + def extract_data(response): - return response['nextLink'], iter(response['value'] or []) + return response["nextLink"], iter(response["value"] or []) pager = ItemPaged(get_next, extract_data) result_iterated = list(pager) @@ -133,31 +106,27 @@ def extract_data(response): def test_print(self): def get_next(continuation_token=None): - return { - 'nextLink': None, - 'value': None - } + return {"nextLink": None, "value": None} + def extract_data(response): - return response['nextLink'], iter(response['value'] or []) + return response["nextLink"], iter(response["value"] or []) pager = ItemPaged(get_next, extract_data) output = repr(pager) - assert output.startswith('\n" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_url_with_params(http_request): @@ -246,22 +267,25 @@ def test_request_url_with_params(http_request): assert request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_url_with_params_as_list(http_request): request = http_request("GET", "/") request.url = "a/b/c?t=y" - request.format_parameters({"g": ["h","i"]}) + request.format_parameters({"g": ["h", "i"]}) assert request.url in ["a/b/c?g=h&g=i&t=y", "a/b/c?t=y&g=h&g=i"] + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_url_with_params_with_none_in_list(http_request): request = http_request("GET", "/") request.url = "a/b/c?t=y" with pytest.raises(ValueError): - request.format_parameters({"g": ["h",None]}) + request.format_parameters({"g": ["h", None]}) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_url_with_params_with_none(http_request): @@ -271,19 +295,21 @@ def test_request_url_with_params_with_none(http_request): with pytest.raises(ValueError): request.format_parameters({"g": None}) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_repr(http_request): request = http_request("GET", "hello.com") assert repr(request) == "" + def test_add_custom_policy(): class BooPolicy(HTTPPolicy): def send(*args): - raise AzureError('boo') + raise AzureError("boo") class FooPolicy(HTTPPolicy): def send(*args): - raise AzureError('boo') + raise AzureError("boo") config = Configuration() retry_policy = RetryPolicy() @@ -328,8 +354,9 @@ def send(*args): assert pos_boo < pos_retry assert pos_foo > pos_retry - client = PipelineClient(base_url="test", config=config, per_call_policies=[boo_policy], - per_retry_policies=[foo_policy]) + client = PipelineClient( + base_url="test", config=config, per_call_policies=[boo_policy], per_retry_policies=[foo_policy] + ) policies = client._pipeline._impl_policies assert boo_policy in policies assert foo_policy in policies @@ -339,9 +366,7 @@ def send(*args): assert pos_boo < pos_retry assert pos_foo > pos_retry - policies = [UserAgentPolicy(), - RetryPolicy(), - DistributedTracingPolicy()] + policies = [UserAgentPolicy(), RetryPolicy(), DistributedTracingPolicy()] client = PipelineClient(base_url="test", policies=policies, per_call_policies=boo_policy) actual_policies = client._pipeline._impl_policies assert boo_policy == actual_policies[0] @@ -356,33 +381,32 @@ def send(*args): actual_policies = client._pipeline._impl_policies assert foo_policy == actual_policies[2] - client = PipelineClient(base_url="test", policies=policies, per_call_policies=boo_policy, - per_retry_policies=foo_policy) + client = PipelineClient( + base_url="test", policies=policies, per_call_policies=boo_policy, per_retry_policies=foo_policy + ) actual_policies = client._pipeline._impl_policies assert boo_policy == actual_policies[0] assert foo_policy == actual_policies[3] - client = PipelineClient(base_url="test", policies=policies, per_call_policies=[boo_policy], - per_retry_policies=[foo_policy]) + client = PipelineClient( + base_url="test", policies=policies, per_call_policies=[boo_policy], per_retry_policies=[foo_policy] + ) actual_policies = client._pipeline._impl_policies assert boo_policy == actual_policies[0] assert foo_policy == actual_policies[3] - policies = [UserAgentPolicy(), - DistributedTracingPolicy()] + policies = [UserAgentPolicy(), DistributedTracingPolicy()] with pytest.raises(ValueError): client = PipelineClient(base_url="test", policies=policies, per_retry_policies=foo_policy) with pytest.raises(ValueError): client = PipelineClient(base_url="test", policies=policies, per_retry_policies=[foo_policy]) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_basic_requests(port, http_request): conf = Configuration() request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), RedirectPolicy()] with Pipeline(RequestsTransport(), policies=policies) as pipeline: response = pipeline.run(request) if is_rest(request): @@ -391,14 +415,12 @@ def test_basic_requests(port, http_request): assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_basic_options_requests(port, http_request): request = http_request("OPTIONS", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), RedirectPolicy()] with Pipeline(RequestsTransport(), policies=policies) as pipeline: response = pipeline.run(request) if is_rest(request): @@ -407,15 +429,13 @@ def test_basic_options_requests(port, http_request): assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_basic_requests_separate_session(port, http_request): session = requests.Session() request = http_request("GET", "http://localhost:{}/basic/string".format(port)) - policies = [ - UserAgentPolicy("myusergant"), - RedirectPolicy() - ] + policies = [UserAgentPolicy("myusergant"), RedirectPolicy()] transport = RequestsTransport(session=session, session_owner=False) with Pipeline(transport, policies=policies) as pipeline: response = pipeline.run(request) @@ -428,28 +448,22 @@ def test_basic_requests_separate_session(port, http_request): assert transport.session transport.session.close() + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_text(port, http_request): client = PipelineClientBase("http://localhost:{}".format(port)) if is_rest(http_request): request = http_request("GET", "/", json="foo") else: - request = client.get( - "/", - content="foo" - ) + request = client.get("/", content="foo") # In absence of information, everything is JSON (double quote added) assert request.data == json.dumps("foo") if is_rest(http_request): - request = http_request("POST", "/", headers={'content-type': 'text/whatever'}, content="foo") + request = http_request("POST", "/", headers={"content-type": "text/whatever"}, content="foo") else: - request = client.post( - "/", - headers={'content-type': 'text/whatever'}, - content="foo" - ) + request = client.post("/", headers={"content-type": "text/whatever"}, content="foo") # We want a direct string assert request.data == "foo" diff --git a/sdk/core/azure-core/tests/test_polling.py b/sdk/core/azure-core/tests/test_polling.py index 79549daadd1d..e3fb92e124b6 100644 --- a/sdk/core/azure-core/tests/test_polling.py +++ b/sdk/core/azure-core/tests/test_polling.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,8 +22,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import time + try: from unittest import mock except ImportError: @@ -34,9 +35,8 @@ from azure.core import PipelineClient from azure.core.exceptions import ServiceResponseError from azure.core.polling import * -from azure.core.polling.base_polling import ( - LROBasePolling, LocationPolling -) +from azure.core.polling.base_polling import LROBasePolling, LocationPolling + # from msrest.serialization import Model @@ -75,46 +75,45 @@ def test_no_polling(client): no_polling = NoPolling() initial_response = "initial response" + def deserialization_cb(response): assert response == initial_response - return "Treated: "+response + return "Treated: " + response no_polling.initialize(client, initial_response, deserialization_cb) - no_polling.run() # Should no raise and do nothing + no_polling.run() # Should no raise and do nothing assert no_polling.status() == "succeeded" assert no_polling.finished() - assert no_polling.resource() == "Treated: "+initial_response + assert no_polling.resource() == "Treated: " + initial_response continuation_token = no_polling.get_continuation_token() assert isinstance(continuation_token, str) no_polling_revived_args = NoPolling.from_continuation_token( - continuation_token, - deserialization_callback=deserialization_cb, - client=client + continuation_token, deserialization_callback=deserialization_cb, client=client ) no_polling_revived = NoPolling() no_polling_revived.initialize(*no_polling_revived_args) assert no_polling_revived.status() == "succeeded" assert no_polling_revived.finished() - assert no_polling_revived.resource() == "Treated: "+initial_response + assert no_polling_revived.resource() == "Treated: " + initial_response + def test_polling_with_path_format_arguments(client): - method = LROBasePolling( - timeout=0, - path_format_arguments={"host": "host:3000", "accountName": "local"} - ) + method = LROBasePolling(timeout=0, path_format_arguments={"host": "host:3000", "accountName": "local"}) client._base_url = "http://{accountName}{host}" method._operation = LocationPolling() method._operation._location_url = "/results/1" method._client = client - assert "http://localhost:3000/results/1" == method._client.format_url(method._operation.get_polling_url(), **method._path_format_arguments) + assert "http://localhost:3000/results/1" == method._client.format_url( + method._operation.get_polling_url(), **method._path_format_arguments + ) class PollingTwoSteps(PollingMethod): - """An empty poller that returns the deserialized initial response. - """ + """An empty poller that returns the deserialized initial response.""" + def __init__(self, sleep=0): self._initial_response = None self._deserialization_callback = None @@ -126,10 +125,9 @@ def initialize(self, _, initial_response, deserialization_callback): self._finished = False def run(self): - """Empty run, no polling. - """ + """Empty run, no polling.""" self._finished = True - time.sleep(self._sleep) # Give me time to add callbacks! + time.sleep(self._sleep) # Give me time to add callbacks! def status(self): """Return the current status as a string. @@ -153,7 +151,7 @@ def get_continuation_token(self): def from_continuation_token(cls, continuation_token, **kwargs): # type(str, Any) -> Tuple initial_response = continuation_token - deserialization_callback = kwargs['deserialization_callback'] + deserialization_callback = kwargs["deserialization_callback"] return None, initial_response, deserialization_callback @@ -165,7 +163,7 @@ def test_poller(client): # Same for deserialization_callback, just pass to the polling_method def deserialization_callback(response): assert response == initial_response - return "Treated: "+response + return "Treated: " + response method = NoPolling() @@ -176,7 +174,7 @@ def deserialization_callback(response): result = poller.result() assert poller.done() - assert result == "Treated: "+initial_response + assert result == "Treated: " + initial_response assert poller.status() == "succeeded" assert poller.polling_method() is method done_cb.assert_called_once_with(method) @@ -195,7 +193,7 @@ def deserialization_callback(response): poller.remove_done_callback(done_cb2) result = poller.result() - assert result == "Treated: "+initial_response + assert result == "Treated: " + initial_response assert poller.status() == "succeeded" done_cb.assert_called_once_with(method) done_cb2.assert_not_called() @@ -213,22 +211,22 @@ def deserialization_callback(response): client=client, initial_response=initial_response, deserialization_callback=deserialization_callback, - polling_method=method + polling_method=method, ) result = new_poller.result() - assert result == "Treated: "+initial_response + assert result == "Treated: " + initial_response assert new_poller.status() == "succeeded" def test_broken_poller(client): - class NoPollingError(PollingTwoSteps): def run(self): raise ValueError("Something bad happened") initial_response = "Initial response" + def deserialization_callback(response): - return "Treated: "+response + return "Treated: " + response method = NoPollingError() poller = LROPoller(client, initial_response, deserialization_callback, method) @@ -239,14 +237,14 @@ def deserialization_callback(response): def test_poller_error_continuation(client): - class NoPollingError(PollingTwoSteps): def run(self): raise ServiceResponseError("Something bad happened") initial_response = "Initial response" + def deserialization_callback(response): - return "Treated: "+response + return "Treated: " + response method = NoPollingError() poller = LROPoller(client, initial_response, deserialization_callback, method) diff --git a/sdk/core/azure-core/tests/test_request_id_policy.py b/sdk/core/azure-core/tests/test_request_id_policy.py index 44c8647982cb..e3c383ab746d 100644 --- a/sdk/core/azure-core/tests/test_request_id_policy.py +++ b/sdk/core/azure-core/tests/test_request_id_policy.py @@ -5,6 +5,7 @@ """Tests for the request id policy.""" from azure.core.pipeline.policies import RequestIdPolicy from azure.core.pipeline import PipelineRequest, PipelineContext + try: from unittest import mock except ImportError: @@ -17,24 +18,29 @@ request_id_init_values = ("foo", None, "_unset") request_id_set_values = ("bar", None, "_unset") request_id_req_values = ("baz", None, "_unset") -full_combination = list(product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values, HTTP_REQUESTS)) +full_combination = list( + product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values, HTTP_REQUESTS) +) + -@pytest.mark.parametrize("auto_request_id, request_id_init, request_id_set, request_id_req, http_request", full_combination) +@pytest.mark.parametrize( + "auto_request_id, request_id_init, request_id_set, request_id_req, http_request", full_combination +) def test_request_id_policy(auto_request_id, request_id_init, request_id_set, request_id_req, http_request): """Test policy with no other policy and happy path""" kwargs = {} if auto_request_id is not None: - kwargs['auto_request_id'] = auto_request_id + kwargs["auto_request_id"] = auto_request_id if request_id_init != "_unset": - kwargs['request_id'] = request_id_init + kwargs["request_id"] = request_id_init request_id_policy = RequestIdPolicy(**kwargs) if request_id_set != "_unset": request_id_policy.set_request_id(request_id_set) - request = http_request('GET', 'http://localhost/') + request = http_request("GET", "http://localhost/") pipeline_request = PipelineRequest(request, PipelineContext(None)) if request_id_req != "_unset": - pipeline_request.context.options['request_id'] = request_id_req - with mock.patch('uuid.uuid1', return_value="VALUE"): + pipeline_request.context.options["request_id"] = request_id_req + with mock.patch("uuid.uuid1", return_value="VALUE"): request_id_policy.on_request(pipeline_request) assert all(v is not None for v in request.headers.values()) @@ -55,11 +61,12 @@ def test_request_id_policy(auto_request_id, request_id_init, request_id_set, req else: assert not "x-ms-client-request-id" in request.headers + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_id_already_exists(http_request): """Test policy with no other policy and happy path""" request_id_policy = RequestIdPolicy() - request = http_request('GET', 'http://localhost/') + request = http_request("GET", "http://localhost/") request.headers["x-ms-client-request-id"] = "VALUE" pipeline_request = PipelineRequest(request, PipelineContext(None)) request_id_policy.on_request(pipeline_request) diff --git a/sdk/core/azure-core/tests/test_requests_universal.py b/sdk/core/azure-core/tests/test_requests_universal.py index 80dbf66299f8..880c414598ea 100644 --- a/sdk/core/azure-core/tests/test_requests_universal.py +++ b/sdk/core/azure-core/tests/test_requests_universal.py @@ -46,13 +46,15 @@ def thread_body(local_sender): future = executor.submit(thread_body, sender) assert future.result() + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_requests_auto_headers(port, http_request): request = http_request("POST", "http://localhost:{}/basic/string".format(port)) with RequestsTransport() as sender: response = sender.send(request) auto_headers = response.internal_response.request.headers - assert 'Content-Type' not in auto_headers + assert "Content-Type" not in auto_headers + def _create_requests_response(http_response, body_bytes, headers=None): # https://github.com/psf/requests/blob/67a7b2e8336951d527e223429672354989384197/requests/adapters.py#L255 @@ -60,40 +62,30 @@ def _create_requests_response(http_response, body_bytes, headers=None): req_response._content = body_bytes req_response._content_consumed = True req_response.status_code = 200 - req_response.reason = 'OK' + req_response.reason = "OK" if headers: # req_response.headers is type CaseInsensitiveDict req_response.headers.update(headers) req_response.encoding = requests.utils.get_encoding_from_headers(req_response.headers) - response = create_transport_response( - http_response, - None, # Don't need a request here - req_response - ) + response = create_transport_response(http_response, None, req_response) # Don't need a request here return response + @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_requests_response_text(http_response): for encoding in ["utf-8", "utf-8-sig", None]: - res = _create_requests_response( - http_response, - b'\xef\xbb\xbf56', - {'Content-Type': 'text/plain'} - ) + res = _create_requests_response(http_response, b"\xef\xbb\xbf56", {"Content-Type": "text/plain"}) if is_rest(http_response): res.read() - assert res.text(encoding) == '56', "Encoding {} didn't work".format(encoding) + assert res.text(encoding) == "56", "Encoding {} didn't work".format(encoding) + @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_repr(http_response): - res = _create_requests_response( - http_response, - b'\xef\xbb\xbf56', - {'Content-Type': 'text/plain'} - ) + res = _create_requests_response(http_response, b"\xef\xbb\xbf56", {"Content-Type": "text/plain"}) class_name = "HttpResponse" if is_rest(http_response) else "RequestsTransportResponse" assert repr(res) == "<{}: 200 OK, Content-Type: text/plain>".format(class_name) diff --git a/sdk/core/azure-core/tests/test_rest_context_manager.py b/sdk/core/azure-core/tests/test_rest_context_manager.py index 0531cfe1505c..194204239617 100644 --- a/sdk/core/azure-core/tests/test_rest_context_manager.py +++ b/sdk/core/azure-core/tests/test_rest_context_manager.py @@ -8,11 +8,13 @@ from azure.core.rest import HttpRequest from azure.core.exceptions import ResponseNotReadError + def test_normal_call(client, port): def _raise_and_get_text(response): response.raise_for_status() assert response.text() == "Hello, world!" assert response.is_closed + request = HttpRequest("GET", url="/basic/string") response = client.send_request(request) _raise_and_get_text(response) @@ -25,6 +27,7 @@ def _raise_and_get_text(response): with response as response: _raise_and_get_text(response) + def test_stream_call(client): def _raise_and_get_text(response): response.raise_for_status() @@ -34,6 +37,7 @@ def _raise_and_get_text(response): response.read() assert response.text() == "Hello, world!" assert response.is_closed + request = HttpRequest("GET", url="/streams/basic") response = client.send_request(request, stream=True) _raise_and_get_text(response) @@ -47,6 +51,7 @@ def _raise_and_get_text(response): with response as response: _raise_and_get_text(response) + # TODO: commenting until https://github.com/Azure/azure-sdk-for-python/issues/18086 is fixed # def test_stream_with_error(client): diff --git a/sdk/core/azure-core/tests/test_rest_headers.py b/sdk/core/azure-core/tests/test_rest_headers.py index 7bbeea9ee000..e5d170d1f738 100644 --- a/sdk/core/azure-core/tests/test_rest_headers.py +++ b/sdk/core/azure-core/tests/test_rest_headers.py @@ -12,13 +12,16 @@ # Thank you httpx for your wonderful tests! from azure.core.rest import HttpRequest + @pytest.fixture def get_request_headers(): def _get_request_headers(header_value): request = HttpRequest(method="GET", url="http://example.org", headers=header_value) return request.headers + return _get_request_headers + # flask returns these response headers, which we don't really need for these following tests RESPONSE_HEADERS_TO_IGNORE = [ "Connection", @@ -28,6 +31,7 @@ def _get_request_headers(header_value): "Date", ] + @pytest.fixture def get_response_headers(client): def _get_response_headers(request): @@ -36,13 +40,16 @@ def _get_response_headers(request): for header in RESPONSE_HEADERS_TO_IGNORE: response.headers.pop(header, None) return response.headers + return _get_response_headers + def test_headers_request(get_request_headers): h = get_request_headers({"a": "123", "b": "789"}) assert h["A"] == "123" assert h["B"] == "789" + def test_headers_response(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) assert "a" in h @@ -58,7 +65,7 @@ def test_headers_response(get_response_headers): assert h.get("nope", default="default") is "default" assert h.get("nope", default=None) is None assert h.get("nope", default=[]) == [] - assert list(h) == ['a', 'b'] + assert list(h) == ["a", "b"] assert set(h.keys()) == set(["a", "b"]) assert list(h.values()) == ["123, 456", "789"] @@ -66,6 +73,7 @@ def test_headers_response(get_response_headers): assert list(h) == ["a", "b"] assert dict(h) == {"a": "123, 456", "b": "789"} + def test_headers_response_keys(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # basically want to make sure this behaves like dict {"a": "123, 456", "b": "789"} @@ -81,8 +89,9 @@ def test_headers_response_keys_mutability(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # test mutability before_mutation_keys = h.keys() - h['c'] = '000' - assert 'c' in before_mutation_keys + h["c"] = "000" + assert "c" in before_mutation_keys + def test_headers_response_values(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) @@ -90,8 +99,8 @@ def test_headers_response_values(get_response_headers): ref_dict = {"a": "123, 456", "b": "789"} assert set(h.values()) == set(ref_dict.values()) assert repr(h.values()) == "ValuesView({'a': '123, 456', 'b': '789'})" - assert '123, 456' in h.values() - assert '789' in h.values() + assert "123, 456" in h.values() + assert "789" in h.values() assert set(h.values()) == set(ref_dict.values()) @@ -99,8 +108,8 @@ def test_headers_response_values_mutability(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # test mutability before_mutation_values = h.values() - h['c'] = '000' - assert '000' in before_mutation_values + h["c"] = "000" + assert "000" in before_mutation_values def test_headers_response_items(get_response_headers): @@ -109,12 +118,12 @@ def test_headers_response_items(get_response_headers): ref_dict = {"a": "123, 456", "b": "789"} assert set(h.items()) == set(ref_dict.items()) assert repr(h.items()) == "ItemsView({'a': '123, 456', 'b': '789'})" - assert ("a", '123, 456') in h.items() - assert not ("a", '123, 456', '123, 456') in h.items() + assert ("a", "123, 456") in h.items() + assert not ("a", "123, 456", "123, 456") in h.items() assert not {"a": "blah", "123, 456": "blah"} in h.items() - assert ("A", '123, 456') in h.items() - assert ("b", '789') in h.items() - assert ("B", '789') in h.items() + assert ("A", "123, 456") in h.items() + assert ("b", "789") in h.items() + assert ("B", "789") in h.items() assert set(h.items()) == set(ref_dict.items()) @@ -122,8 +131,9 @@ def test_headers_response_items_mutability(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers")) # test mutability before_mutation_items = h.items() - h['c'] = '000' - assert ('c', '000') in before_mutation_items + h["c"] = "000" + assert ("c", "000") in before_mutation_items + def test_header_mutations(get_request_headers, get_response_headers): def _headers_check(h): @@ -138,21 +148,29 @@ def _headers_check(h): assert dict(h) == {"a": "2", "b": "4"} del h["a"] assert dict(h) == {"b": "4"} + _headers_check(get_request_headers({})) _headers_check(get_response_headers(HttpRequest("GET", "/headers/empty"))) + def test_copy_headers_method(get_request_headers, get_response_headers): def _header_check(h): headers_copy = h.copy() assert h == headers_copy assert h is not headers_copy - _header_check(get_request_headers({ - "lowercase-header": "lowercase", - "ALLCAPS-HEADER": "ALLCAPS", - "CamelCase-Header": "camelCase", - })) + + _header_check( + get_request_headers( + { + "lowercase-header": "lowercase", + "ALLCAPS-HEADER": "ALLCAPS", + "CamelCase-Header": "camelCase", + } + ) + ) _header_check(get_response_headers(HttpRequest("GET", "/headers/case-insensitive"))) + def test_headers_insert_retains_ordering(get_request_headers, get_response_headers): def _header_check(h): h["b"] = "123" @@ -160,6 +178,7 @@ def _header_check(h): assert list(h.values()) == ["a", "123", "c"] else: assert set(h.values()) == set(["a", "123", "c"]) + _header_check(get_request_headers({"a": "a", "b": "b", "c": "c"})) _header_check(get_response_headers(HttpRequest("GET", "/headers/ordered"))) @@ -171,11 +190,16 @@ def _headers_check(h): assert list(h.values()) == ["lowercase", "ALLCAPS", "camelCase", "123"] else: assert set(list(h.values())) == set(["lowercase", "ALLCAPS", "camelCase", "123"]) - _headers_check(get_request_headers({ - "lowercase-header": "lowercase", - "ALLCAPS-HEADER": "ALLCAPS", - "CamelCase-Header": "camelCase", - })) + + _headers_check( + get_request_headers( + { + "lowercase-header": "lowercase", + "ALLCAPS-HEADER": "ALLCAPS", + "CamelCase-Header": "camelCase", + } + ) + ) _headers_check(get_response_headers(HttpRequest("GET", "/headers/case-insensitive"))) @@ -183,6 +207,7 @@ def test_headers_insert_removes_all_existing(get_request_headers, get_response_h def _headers_check(h): h["a"] = "789" assert dict(h) == {"a": "789", "b": "789"} + _headers_check(get_request_headers([("a", "123"), ("a", "456"), ("b", "789")])) _headers_check(get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers"))) @@ -191,60 +216,65 @@ def test_headers_delete_removes_all_existing(get_request_headers, get_response_h def _headers_check(h): del h["a"] assert dict(h) == {"b": "789"} + _headers_check(get_request_headers([("a", "123"), ("a", "456"), ("b", "789")])) _headers_check(get_response_headers(HttpRequest("GET", "/headers/duplicate/numbers"))) + def test_headers_not_override(): - request = HttpRequest("PUT", "http://example.org", json={"hello": "world"}, headers={"Content-Length": "5000", "Content-Type": "application/my-content-type"}) + request = HttpRequest( + "PUT", + "http://example.org", + json={"hello": "world"}, + headers={"Content-Length": "5000", "Content-Type": "application/my-content-type"}, + ) assert request.headers["Content-Length"] == "5000" assert request.headers["Content-Type"] == "application/my-content-type" + def test_headers_case_insensitive(get_request_headers, get_response_headers): def _headers_check(h): assert ( - h["lowercase-header"] == - h["LOWERCASE-HEADER"] == - h["Lowercase-Header"] == - h["lOwErCasE-HeADer"] == - "lowercase" + h["lowercase-header"] + == h["LOWERCASE-HEADER"] + == h["Lowercase-Header"] + == h["lOwErCasE-HeADer"] + == "lowercase" ) + assert h["allcaps-header"] == h["ALLCAPS-HEADER"] == h["Allcaps-Header"] == h["AlLCapS-HeADer"] == "ALLCAPS" assert ( - h["allcaps-header"] == - h["ALLCAPS-HEADER"] == - h["Allcaps-Header"] == - h["AlLCapS-HeADer"] == - "ALLCAPS" + h["camelcase-header"] + == h["CAMELCASE-HEADER"] + == h["CamelCase-Header"] + == h["cAMeLCaSE-hEadER"] + == "camelCase" ) - assert ( - h["camelcase-header"] == - h["CAMELCASE-HEADER"] == - h["CamelCase-Header"] == - h["cAMeLCaSE-hEadER"] == - "camelCase" + + _headers_check( + get_request_headers( + { + "lowercase-header": "lowercase", + "ALLCAPS-HEADER": "ALLCAPS", + "CamelCase-Header": "camelCase", + } ) - _headers_check(get_request_headers({ - "lowercase-header": "lowercase", - "ALLCAPS-HEADER": "ALLCAPS", - "CamelCase-Header": "camelCase", - })) + ) _headers_check(get_response_headers(HttpRequest("GET", "/headers/case-insensitive"))) + def test_multiple_headers_duplicate_case_insensitive(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/case-insensitive")) - assert ( - h["Duplicate-Header"] == - h['duplicate-header'] == - h['DupLicAte-HeaDER'] == - "one, two, three" - ) + assert h["Duplicate-Header"] == h["duplicate-header"] == h["DupLicAte-HeaDER"] == "one, two, three" + def test_multiple_headers_commas(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/commas")) assert h["Set-Cookie"] == "a, b, c" + def test_update(get_response_headers): h = get_response_headers(HttpRequest("GET", "/headers/duplicate/commas")) assert h["Set-Cookie"] == "a, b, c" h.update({"Set-Cookie": "override", "new-key": "new-value"}) - assert h['Set-Cookie'] == 'override' + assert h["Set-Cookie"] == "override" assert h["new-key"] == "new-value" diff --git a/sdk/core/azure-core/tests/test_rest_http_request.py b/sdk/core/azure-core/tests/test_rest_http_request.py index 7cbeba563407..34bc046d7263 100644 --- a/sdk/core/azure-core/tests/test_rest_http_request.py +++ b/sdk/core/azure-core/tests/test_rest_http_request.py @@ -11,6 +11,7 @@ import pytest import sys import os + try: import collections.abc as collections except ImportError: @@ -18,28 +19,31 @@ from azure.core.configuration import Configuration from azure.core.rest import HttpRequest -from azure.core.pipeline.policies import ( - CustomHookPolicy, UserAgentPolicy, SansIOHTTPPolicy, RetryPolicy -) +from azure.core.pipeline.policies import CustomHookPolicy, UserAgentPolicy, SansIOHTTPPolicy, RetryPolicy from azure.core.pipeline._tools import is_rest from rest_client import TestRestClient from azure.core import PipelineClient + @pytest.fixture def assert_iterator_body(): def _comparer(request, final_value): content = b"".join([p for p in request.content]) assert content == final_value + return _comparer + def test_request_repr(): request = HttpRequest("GET", "http://example.org") assert repr(request) == "" + def test_no_content(): request = HttpRequest("GET", "http://example.org") assert "Content-Length" not in request.headers + def test_content_length_header(): request = HttpRequest("POST", "http://example.org", content=b"test 123") assert request.headers["Content-Length"] == "8" @@ -69,9 +73,7 @@ def content(): yield b"test 123" # pragma: nocover headers = {"Content-Length": "8"} - request = HttpRequest( - "POST", "http://example.org", content=content(), headers=headers - ) + request = HttpRequest("POST", "http://example.org", content=content(), headers=headers) assert request.headers == {"Content-Length": "8"} assert_iterator_body(request, b"test 123") @@ -80,7 +82,9 @@ def test_url_encoded_data(): request = HttpRequest("POST", "http://example.org", data={"test": "123"}) assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" - assert request.content == {'test': '123'} # httpx makes this just b'test=123'. set_formdata_body is still keeping it as a dict + assert request.content == { + "test": "123" + } # httpx makes this just b'test=123'. set_formdata_body is still keeping it as a dict def test_json_encoded_data(): @@ -115,17 +119,22 @@ def streaming_body(data): assert "Transfer-Encoding" not in request.headers assert request.headers["Content-Length"] == "4" + def test_override_accept_encoding_header(): headers = {"Accept-Encoding": "identity"} request = HttpRequest("GET", "http://example.org", headers=headers) assert request.headers["Accept-Encoding"] == "identity" + """Test request body""" + + def test_empty_content(): request = HttpRequest("GET", "http://example.org") assert request.content is None + def test_string_content(): request = HttpRequest("PUT", "http://example.org", content="Hello, world!") assert request.headers == {"Content-Length": "13", "Content-Type": "text/plain"} @@ -163,6 +172,7 @@ def test_bytes_content(): assert request.headers == {"Content-Length": "13"} assert request.content == b"Hello, world!" + def test_iterator_content(assert_iterator_body): # NOTE: in httpx, content reads out the actual value. Don't do that (yet) in azure rest def hello_world(): @@ -199,6 +209,7 @@ def test_json_content(): } assert request.content == '{"Hello": "world!"}' + def test_urlencoded_content(): # NOTE: not adding content length setting and content testing bc we're not adding content length in the rest code # that's dealt with later in the pipeline. @@ -207,6 +218,7 @@ def test_urlencoded_content(): "Content-Type": "application/x-www-form-urlencoded", } + @pytest.mark.parametrize(("key"), (1, 2.3, None)) def test_multipart_invalid_key(key): @@ -237,26 +249,28 @@ def test_multipart_invalid_key_binary_string(): assert "Invalid type for data name" in str(e.value) assert repr(b"abc") in str(e.value) + def test_data_str_input(): data = { - 'scope': 'fake_scope', - u'grant_type': 'refresh_token', - 'refresh_token': u'REDACTED', - 'service': 'fake_url.azurecr.io' + "scope": "fake_scope", + "grant_type": "refresh_token", + "refresh_token": "REDACTED", + "service": "fake_url.azurecr.io", } request = HttpRequest("POST", "http://localhost:3000/", data=data) assert len(request.content) == 4 assert request.content["scope"] == "fake_scope" assert request.content["grant_type"] == "refresh_token" - assert request.content["refresh_token"] == u"REDACTED" + assert request.content["refresh_token"] == "REDACTED" assert request.content["service"] == "fake_url.azurecr.io" assert len(request.headers) == 1 - assert request.headers['Content-Type'] == 'application/x-www-form-urlencoded' + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + def test_content_str_input(): requests = [ HttpRequest("POST", "/fake", content="hello, world!"), - HttpRequest("POST", "/fake", content=u"hello, world!"), + HttpRequest("POST", "/fake", content="hello, world!"), ] for request in requests: assert len(request.headers) == 2 @@ -264,6 +278,7 @@ def test_content_str_input(): assert request.headers["Content-Length"] == "13" assert request.content == "hello, world!" + @pytest.mark.parametrize(("value"), (object(), {"key": "value"})) def test_multipart_invalid_value(value): @@ -273,11 +288,13 @@ def test_multipart_invalid_value(value): HttpRequest("POST", "http://localhost:8000/", data=data, files=files) assert "Invalid type for data value" in str(e.value) + def test_empty_request(): request = HttpRequest("POST", url="http://example.org", data={}, files={}) assert request.headers == {} - assert not request.content # in core, we don't convert urlencoded dict to bytes representation in content + assert not request.content # in core, we don't convert urlencoded dict to bytes representation in content + def test_read_content(assert_iterator_body): def content(): @@ -288,22 +305,23 @@ def content(): # in this case, request._data is what we end up passing to the requests transport assert isinstance(request._data, collections.Iterable) + def test_complicated_json(client): # thanks to Sean Kane for this test! input = { - 'EmptyByte': '', - 'EmptyUnicode': '', - 'SpacesOnlyByte': ' ', - 'SpacesOnlyUnicode': ' ', - 'SpacesBeforeByte': ' Text', - 'SpacesBeforeUnicode': ' Text', - 'SpacesAfterByte': 'Text ', - 'SpacesAfterUnicode': 'Text ', - 'SpacesBeforeAndAfterByte': ' Text ', - 'SpacesBeforeAndAfterUnicode': ' Text ', - '啊齄丂狛': 'ꀕ', - 'RowKey': 'test2', - '啊齄丂狛狜': 'hello', + "EmptyByte": "", + "EmptyUnicode": "", + "SpacesOnlyByte": " ", + "SpacesOnlyUnicode": " ", + "SpacesBeforeByte": " Text", + "SpacesBeforeUnicode": " Text", + "SpacesAfterByte": "Text ", + "SpacesAfterUnicode": "Text ", + "SpacesBeforeAndAfterByte": " Text ", + "SpacesBeforeAndAfterUnicode": " Text ", + "啊齄丂狛": "ꀕ", + "RowKey": "test2", + "啊齄丂狛狜": "hello", "singlequote": "a''''b", "doublequote": 'a""""b', "None": None, @@ -312,25 +330,26 @@ def test_complicated_json(client): r = client.send_request(request) r.raise_for_status() + def test_use_custom_json_encoder(): # this is to test we're using azure.core.serialization.AzureJSONEncoder # to serialize our JSON objects # since json can't serialize bytes by default but AzureJSONEncoder can, # we pass in bytes and check that they are serialized request = HttpRequest("GET", "/headers", json=bytearray("mybytes", "utf-8")) - assert request.content == '"bXlieXRlcw=="' # cspell:disable-line + assert request.content == '"bXlieXRlcw=="' # cspell:disable-line + def test_request_policies_raw_request_hook(port): # test that the request all the way through the pipeline is a new request request = HttpRequest("GET", "/headers") + def callback(request): assert is_rest(request.http_request) raise ValueError("I entered the callback!") + custom_hook_policy = CustomHookPolicy(raw_request_hook=callback) - policies = [ - UserAgentPolicy("myuseragent"), - custom_hook_policy - ] + policies = [UserAgentPolicy("myuseragent"), custom_hook_policy] client = TestRestClient(port=port, policies=policies) with pytest.raises(ValueError) as ex: @@ -341,7 +360,7 @@ def callback(request): def test_request_policies_chain(port): class OldPolicyModifyBody(SansIOHTTPPolicy): def on_request(self, request): - assert is_rest(request.http_request) # first make sure this is a new request + assert is_rest(request.http_request) # first make sure this is a new request # deals with request like an old request request.http_request.set_json_body({"hello": "world"}) @@ -353,7 +372,7 @@ def on_request(self, request): # modify header to know we entered this callback request.http_request.headers = { "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", - "Authorization": "SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=", # fake key suppressed in credscan + "Authorization": "SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=", # fake key suppressed in credscan "Content-Length": "0", } @@ -363,11 +382,11 @@ def on_request(self, request): # don't want to deal with content in serialize, so let's first just remove it request.http_request.data = None expected = ( - b'DELETE http://localhost:5000/container0/blob0 HTTP/1.1\r\n' - b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' - b'Authorization: SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=\r\n' # fake key suppressed in credscan - b'Content-Length: 0\r\n' - b'\r\n' + b"DELETE http://localhost:5000/container0/blob0 HTTP/1.1\r\n" + b"x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n" + b"Authorization: SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=\r\n" # fake key suppressed in credscan + b"Content-Length: 0\r\n" + b"\r\n" ) assert request.http_request.serialize() == expected raise ValueError("Passed through the policies!") @@ -403,17 +422,14 @@ def on_request(self, pipeline_request): return pipeline_request class NewPolicy(SansIOHTTPPolicy): - def on_request(self, pipeline_request): request = pipeline_request.http_request assert is_rest(request) - assert request.content == 'change to me!' # new request has property content + assert request.content == "change to me!" # new request has property content raise ValueError("I entered the policies!") pipeline_client = PipelineClient( - base_url="http://localhost:{}".format(port), - config=config, - per_call_policies=[OldPolicy(), NewPolicy()] + base_url="http://localhost:{}".format(port), config=config, per_call_policies=[OldPolicy(), NewPolicy()] ) client = TestRestClient(port=port) client._client = pipeline_client @@ -426,8 +442,9 @@ def on_request(self, pipeline_request): # work assert "I entered the policies!" in str(ex.value) + def test_json_file_valid(): - json_bytes = bytearray('{"more": "cowbell"}', encoding='utf-8') + json_bytes = bytearray('{"more": "cowbell"}', encoding="utf-8") with io.BytesIO(json_bytes) as json_file: request = HttpRequest("PUT", "/fake", json=json_file) assert request.headers == {"Content-Type": "application/json"} @@ -435,8 +452,9 @@ def test_json_file_valid(): assert not request.content.closed assert request.content.read() == b'{"more": "cowbell"}' + def test_json_file_invalid(): - json_bytes = bytearray('{"more": "cowbell" i am not valid', encoding='utf-8') + json_bytes = bytearray('{"more": "cowbell" i am not valid', encoding="utf-8") with io.BytesIO(json_bytes) as json_file: request = HttpRequest("PUT", "/fake", json=json_file) assert request.headers == {"Content-Type": "application/json"} @@ -444,8 +462,9 @@ def test_json_file_invalid(): assert not request.content.closed assert request.content.read() == b'{"more": "cowbell" i am not valid' + def test_json_file_content_type_input(): - json_bytes = bytearray('{"more": "cowbell"}', encoding='utf-8') + json_bytes = bytearray('{"more": "cowbell"}', encoding="utf-8") with io.BytesIO(json_bytes) as json_file: request = HttpRequest("PUT", "/fake", json=json_file, headers={"Content-Type": "application/json-special"}) assert request.headers == {"Content-Type": "application/json-special"} @@ -453,6 +472,7 @@ def test_json_file_content_type_input(): assert not request.content.closed assert request.content.read() == b'{"more": "cowbell"}' + class NonSeekableStream: def __init__(self, wrapped_stream): self.wrapped_stream = wrapped_stream @@ -469,13 +489,15 @@ def seek(self, *args, **kwargs): def tell(self): return self.wrapped_stream.tell() + def test_non_seekable_stream_input(): data = b"a" * 4 * 1024 data_stream = NonSeekableStream(io.BytesIO(data)) - HttpRequest(method="PUT", url="http://www.example.com", content=data_stream) # ensure we can make this HttpRequest + HttpRequest(method="PUT", url="http://www.example.com", content=data_stream) # ensure we can make this HttpRequest + class Stream: - def __init__(self, length, initial_buffer_length=4*1024): + def __init__(self, length, initial_buffer_length=4 * 1024): self._base_data = os.urandom(initial_buffer_length) self._base_data_length = initial_buffer_length self._position = 0 @@ -487,9 +509,10 @@ def read(self, size=None): def remaining(self): return self._remaining + def test_stream_input(): data_stream = Stream(length=4) - HttpRequest(method="PUT", url="http://www.example.com", content=data_stream) # ensure we can make this HttpRequest + HttpRequest(method="PUT", url="http://www.example.com", content=data_stream) # ensure we can make this HttpRequest # NOTE: For files, we don't allow list of tuples yet, just dict. Will uncomment when we add this capability diff --git a/sdk/core/azure-core/tests/test_rest_http_response.py b/sdk/core/azure-core/tests/test_rest_http_response.py index e0a6cd0d6839..0f3219517103 100644 --- a/sdk/core/azure-core/tests/test_rest_http_response.py +++ b/sdk/core/azure-core/tests/test_rest_http_response.py @@ -17,14 +17,17 @@ import xml.etree.ElementTree as ET from utils import readonly_checks + @pytest.fixture def send_request(client): def _send_request(request): response = client.send_request(request, stream=False) response.raise_for_status() return response + return _send_request + def test_response(send_request, port): response = send_request( request=HttpRequest("GET", "/basic/string"), @@ -52,10 +55,11 @@ def test_response_text(send_request): assert response.status_code == 200 assert response.reason == "OK" assert response.text() == "Hello, world!" - assert response.headers["Content-Length"] == '13' - assert response.headers['Content-Type'] == "text/plain; charset=utf-8" + assert response.headers["Content-Length"] == "13" + assert response.headers["Content-Type"] == "text/plain; charset=utf-8" assert response.content_type == "text/plain; charset=utf-8" + def test_response_html(send_request): response = send_request( request=HttpRequest("GET", "/basic/html"), @@ -64,6 +68,7 @@ def test_response_html(send_request): assert response.reason == "OK" assert response.text() == "Hello, world!" + def test_raise_for_status(client): response = client.send_request( HttpRequest("GET", "/basic/string"), @@ -85,21 +90,19 @@ def test_raise_for_status(client): with pytest.raises(HttpResponseError): response.raise_for_status() + def test_response_repr(send_request): - response = send_request( - request=HttpRequest("GET", "/basic/string") - ) + response = send_request(request=HttpRequest("GET", "/basic/string")) assert repr(response) == "" + def test_response_content_type_encoding(send_request): """ Use the charset encoding in the Content-Type header if possible. """ - response = send_request( - request=HttpRequest("GET", "/encoding/latin-1") - ) + response = send_request(request=HttpRequest("GET", "/encoding/latin-1")) assert response.content_type == "text/plain; charset=latin-1" - assert response.text() == u"Latin 1: ÿ" + assert response.text() == "Latin 1: ÿ" assert response.encoding == "latin-1" @@ -107,11 +110,9 @@ def test_response_autodetect_encoding(send_request): """ Autodetect encoding if there is no Content-Type header. """ - response = send_request( - request=HttpRequest("GET", "/encoding/latin-1") - ) + response = send_request(request=HttpRequest("GET", "/encoding/latin-1")) - assert response.text() == u'Latin 1: ÿ' + assert response.text() == "Latin 1: ÿ" assert response.encoding == "latin-1" @@ -119,12 +120,10 @@ def test_response_fallback_to_autodetect(send_request): """ Fallback to autodetection if we get an invalid charset in the Content-Type header. """ - response = send_request( - request=HttpRequest("GET", "/encoding/invalid-codec-name") - ) + response = send_request(request=HttpRequest("GET", "/encoding/invalid-codec-name")) assert response.headers["Content-Type"] == "text/plain; charset=invalid-codec-name" - assert response.text() == u"おはようございます。" + assert response.text() == "おはようございます。" assert response.encoding is None @@ -151,7 +150,7 @@ def test_response_no_charset_with_iso_8859_1_content(send_request): response = send_request( request=HttpRequest("GET", "/encoding/iso-8859-1"), ) - assert response.text() == u"Accented: �sterreich" + assert response.text() == "Accented: �sterreich" assert response.encoding is None @@ -162,6 +161,7 @@ def test_json(send_request): assert response.json() == {"greeting": "hello", "recipient": "world"} assert response.encoding is None + def test_json_with_specified_encoding(send_request): response = send_request( request=HttpRequest("GET", "/encoding/json"), @@ -169,33 +169,36 @@ def test_json_with_specified_encoding(send_request): assert response.json() == {"greeting": "hello", "recipient": "world"} assert response.encoding == "utf-16" + def test_emoji(send_request): response = send_request( request=HttpRequest("GET", "/encoding/emoji"), ) - assert response.text() == u"👩" + assert response.text() == "👩" + def test_emoji_family_with_skin_tone_modifier(send_request): response = send_request( request=HttpRequest("GET", "/encoding/emoji-family-skin-tone-modifier"), ) - assert response.text() == u"👩🏻‍👩🏽‍👧🏾‍👦🏿 SSN: 859-98-0987" + assert response.text() == "👩🏻‍👩🏽‍👧🏾‍👦🏿 SSN: 859-98-0987" + def test_korean_nfc(send_request): response = send_request( request=HttpRequest("GET", "/encoding/korean"), ) - assert response.text() == u"아가" + assert response.text() == "아가" + def test_urlencoded_content(send_request): send_request( request=HttpRequest( - "POST", - "/urlencoded/pet/add/1", - data={ "pet_type": "dog", "pet_food": "meat", "name": "Fido", "pet_age": 42 } + "POST", "/urlencoded/pet/add/1", data={"pet_type": "dog", "pet_food": "meat", "name": "Fido", "pet_age": 42} ), ) + def test_multipart_files_content(send_request): request = HttpRequest( "POST", @@ -204,6 +207,7 @@ def test_multipart_files_content(send_request): ) send_request(request) + def test_multipart_data_and_files_content(send_request): request = HttpRequest( "POST", @@ -240,6 +244,7 @@ def data(): ) send_request(request) + def test_get_xml_basic(send_request): request = HttpRequest( "GET", @@ -247,11 +252,12 @@ def test_get_xml_basic(send_request): ) response = send_request(request) parsed_xml = ET.fromstring(response.text()) - assert parsed_xml.tag == 'slideshow' + assert parsed_xml.tag == "slideshow" attributes = parsed_xml.attrib - assert attributes['title'] == "Sample Slide Show" - assert attributes['date'] == "Date of publication" - assert attributes['author'] == "Yours Truly" + assert attributes["title"] == "Sample Slide Show" + assert attributes["date"] == "Date of publication" + assert attributes["author"] == "Yours Truly" + def test_put_xml_basic(send_request): @@ -278,6 +284,7 @@ def test_put_xml_basic(send_request): ) send_request(request) + def test_send_request_return_pipeline_response(client): # we use return_pipeline_response for some cases in autorest request = HttpRequest("GET", "/basic/string") @@ -288,43 +295,48 @@ def test_send_request_return_pipeline_response(client): assert response.http_response.text() == "Hello, world!" assert hasattr(response.http_request, "content") + def test_text_and_encoding(send_request): response = send_request( request=HttpRequest("GET", "/encoding/emoji"), ) - assert response.content == u"👩".encode("utf-8") - assert response.text() == u"👩" + assert response.content == "👩".encode("utf-8") + assert response.text() == "👩" # try setting encoding as a property response.encoding = "utf-16" - assert response.text() == u"鿰ꦑ" == response.content.decode(response.encoding) + assert response.text() == "鿰ꦑ" == response.content.decode(response.encoding) # assert latin-1 changes text decoding without changing encoding property - assert response.text("latin-1") == u'ð\x9f\x91©' == response.content.decode("latin-1") + assert response.text("latin-1") == "ð\x9f\x91©" == response.content.decode("latin-1") assert response.encoding == "utf-16" + def test_passing_encoding_to_text(send_request): response = send_request( request=HttpRequest("GET", "/encoding/emoji"), ) - assert response.content == u"👩".encode("utf-8") - assert response.text() == u"👩" + assert response.content == "👩".encode("utf-8") + assert response.text() == "👩" # pass in different encoding - assert response.text("latin-1") == u'ð\x9f\x91©' + assert response.text("latin-1") == "ð\x9f\x91©" # check response.text() still gets us the old value - assert response.text() == u"👩" + assert response.text() == "👩" + def test_initialize_response_abc(): with pytest.raises(TypeError) as ex: HttpResponse() assert "Can't instantiate abstract class" in str(ex) + def test_readonly(send_request): """Make sure everything that is readonly is readonly""" response = send_request(HttpRequest("GET", "/health")) assert isinstance(response, RestRequestsTransportResponse) from azure.core.pipeline.transport import RequestsTransportResponse + readonly_checks(response, old_response_class=RequestsTransportResponse) diff --git a/sdk/core/azure-core/tests/test_rest_polling.py b/sdk/core/azure-core/tests/test_rest_polling.py index 3dfebef958db..fad744dee647 100644 --- a/sdk/core/azure-core/tests/test_rest_polling.py +++ b/sdk/core/azure-core/tests/test_rest_polling.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,89 +22,104 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import pytest from azure.core.exceptions import ServiceRequestError from azure.core.rest import HttpRequest from azure.core.polling import LROPoller from azure.core.polling.base_polling import LROBasePolling + @pytest.fixture def deserialization_callback(): def _callback(response): return response.http_response.json() + return _callback + @pytest.fixture def lro_poller(client, deserialization_callback): def _callback(request, **kwargs): - initial_response = client.send_request( - request=request, - _return_pipeline_response=True - ) + initial_response = client.send_request(request=request, _return_pipeline_response=True) return LROPoller( client._client, initial_response, deserialization_callback, LROBasePolling(0, **kwargs), ) + return _callback + def test_post_with_location_and_operation_location_headers(lro_poller): poller = lro_poller(HttpRequest("POST", "/polling/post/location-and-operation-location")) result = poller.result() - assert result == {'location_result': True} + assert result == {"location_result": True} + def test_post_with_location_and_operation_location_headers_no_body(lro_poller): poller = lro_poller(HttpRequest("POST", "/polling/post/location-and-operation-location-no-body")) result = poller.result() assert result is None + def test_post_resource_location(lro_poller): poller = lro_poller(HttpRequest("POST", "/polling/post/resource-location")) result = poller.result() - assert result == {'location_result': True} + assert result == {"location_result": True} + def test_put_no_polling(lro_poller): result = lro_poller(HttpRequest("PUT", "/polling/no-polling")).result() - assert result['properties']['provisioningState'] == 'Succeeded' + assert result["properties"]["provisioningState"] == "Succeeded" + def test_put_location(lro_poller): result = lro_poller(HttpRequest("PUT", "/polling/location")).result() - assert result['location_result'] + assert result["location_result"] + def test_put_initial_response_body_invalid(lro_poller): # initial body is invalid result = lro_poller(HttpRequest("PUT", "/polling/initial-body-invalid")).result() - assert result['location_result'] + assert result["location_result"] + def test_put_operation_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): lro_poller(HttpRequest("PUT", "/polling/bad-operation-location"), retry_total=0).result() + def test_put_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): lro_poller(HttpRequest("PUT", "/polling/bad-location"), retry_total=0).result() + def test_patch_location(lro_poller): result = lro_poller(HttpRequest("PATCH", "/polling/location")).result() - assert result['location_result'] + assert result["location_result"] + def test_patch_operation_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): lro_poller(HttpRequest("PUT", "/polling/bad-operation-location"), retry_total=0).result() + def test_patch_location_polling_fail(lro_poller): with pytest.raises(ServiceRequestError): lro_poller(HttpRequest("PUT", "/polling/bad-location"), retry_total=0).result() + def test_delete_operation_location(lro_poller): result = lro_poller(HttpRequest("DELETE", "/polling/operation-location")).result() - assert result['status'] == 'Succeeded' + assert result["status"] == "Succeeded" + def test_request_id(lro_poller): result = lro_poller(HttpRequest("POST", "/polling/request-id"), request_id="123456789").result() + def test_continuation_token(client, lro_poller, deserialization_callback): poller = lro_poller(HttpRequest("POST", "/polling/post/location-and-operation-location")) token = poller.continuation_token() @@ -115,4 +130,4 @@ def test_continuation_token(client, lro_poller, deserialization_callback): deserialization_callback=deserialization_callback, ) result = new_poller.result() - assert result == {'location_result': True} + assert result == {"location_result": True} diff --git a/sdk/core/azure-core/tests/test_rest_query.py b/sdk/core/azure-core/tests/test_rest_query.py index 7aeda360da6a..e35715ee5e5e 100644 --- a/sdk/core/azure-core/tests/test_rest_query.py +++ b/sdk/core/azure-core/tests/test_rest_query.py @@ -10,21 +10,26 @@ import pytest from azure.core.rest import HttpRequest + def _format_query_into_url(url, params): request = HttpRequest(method="GET", url=url, params=params) return request.url + def test_request_url_with_params(): url = _format_query_into_url(url="a/b/c?t=y", params={"g": "h"}) assert url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + def test_request_url_with_params_as_list(): - url = _format_query_into_url(url="a/b/c?t=y", params={"g": ["h","i"]}) + url = _format_query_into_url(url="a/b/c?t=y", params={"g": ["h", "i"]}) assert url in ["a/b/c?g=h&g=i&t=y", "a/b/c?t=y&g=h&g=i"] + def test_request_url_with_params_with_none_in_list(): with pytest.raises(ValueError): - _format_query_into_url(url="a/b/c?t=y", params={"g": ["h",None]}) + _format_query_into_url(url="a/b/c?t=y", params={"g": ["h", None]}) + def test_request_url_with_params_with_none(): with pytest.raises(ValueError): diff --git a/sdk/core/azure-core/tests/test_rest_request_backcompat.py b/sdk/core/azure-core/tests/test_rest_request_backcompat.py index e2d4f866f470..17751f7625b3 100644 --- a/sdk/core/azure-core/tests/test_rest_request_backcompat.py +++ b/sdk/core/azure-core/tests/test_rest_request_backcompat.py @@ -10,25 +10,30 @@ import xml.etree.ElementTree as ET from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest from azure.core.rest import HttpRequest as RestHttpRequest + try: import collections.abc as collections except ImportError: import collections + @pytest.fixture def old_request(): return PipelineTransportHttpRequest("GET", "/") + @pytest.fixture def new_request(): return RestHttpRequest("GET", "/") + def test_request_attr_parity(old_request, new_request): for attr in dir(old_request): if not attr[0] == "_": # if not a private attr, we want parity assert hasattr(new_request, attr) + def test_request_set_attrs(old_request, new_request): for attr in dir(old_request): if attr[0] == "_": @@ -43,65 +48,56 @@ def test_request_set_attrs(old_request, new_request): setattr(new_request, attr, "foo") assert getattr(old_request, attr) == getattr(new_request, attr) == "foo" + def test_request_multipart_mixed_info(old_request, new_request): old_request.multipart_mixed_info = "foo" new_request.multipart_mixed_info = "foo" assert old_request.multipart_mixed_info == new_request.multipart_mixed_info == "foo" + def test_request_files_attr(old_request, new_request): assert old_request.files == new_request.files == None old_request.files = {"hello": "world"} new_request.files = {"hello": "world"} assert old_request.files == new_request.files == {"hello": "world"} + def test_request_data_attr(old_request, new_request): assert old_request.data == new_request.data == None old_request.data = {"hello": "world"} new_request.data = {"hello": "world"} assert old_request.data == new_request.data == {"hello": "world"} + def test_request_query(old_request, new_request): assert old_request.query == new_request.query == {} old_request.url = "http://localhost:5000?a=b&c=d" new_request.url = "http://localhost:5000?a=b&c=d" - assert old_request.query == new_request.query == {'a': 'b', 'c': 'd'} + assert old_request.query == new_request.query == {"a": "b", "c": "d"} + def test_request_query_and_params_kwarg(old_request): # should be same behavior if we pass in query params through the params kwarg in the new requests old_request.url = "http://localhost:5000?a=b&c=d" - new_request = RestHttpRequest("GET", "http://localhost:5000", params={'a': 'b', 'c': 'd'}) - assert old_request.query == new_request.query == {'a': 'b', 'c': 'd'} + new_request = RestHttpRequest("GET", "http://localhost:5000", params={"a": "b", "c": "d"}) + assert old_request.query == new_request.query == {"a": "b", "c": "d"} + def test_request_body(old_request, new_request): assert old_request.body == new_request.body == None old_request.data = {"hello": "world"} new_request.data = {"hello": "world"} - assert ( - old_request.body == - new_request.body == - new_request.content == - {"hello": "world"} - ) + assert old_request.body == new_request.body == new_request.content == {"hello": "world"} # files will not override data old_request.files = {"foo": "bar"} new_request.files = {"foo": "bar"} - assert ( - old_request.body == - new_request.body == - new_request.content == - {"hello": "world"} - ) + assert old_request.body == new_request.body == new_request.content == {"hello": "world"} # nullify data old_request.data = None new_request.data = None - assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - None - ) + assert old_request.data == new_request.data == old_request.body == new_request.body == None + def test_format_parameters(old_request, new_request): old_request.url = "a/b/c?t=y" @@ -114,14 +110,13 @@ def test_format_parameters(old_request, new_request): assert old_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + def test_request_format_parameters_and_params_kwarg(old_request): # calling format_parameters on an old request should be the same # behavior as passing in params to new request old_request.url = "a/b/c?t=y" old_request.format_parameters({"g": "h"}) - new_request = RestHttpRequest( - "GET", "a/b/c?t=y", params={"g": "h"} - ) + new_request = RestHttpRequest("GET", "a/b/c?t=y", params={"g": "h"}) assert old_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] @@ -130,6 +125,7 @@ def test_request_format_parameters_and_params_kwarg(old_request): assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + def test_request_streamed_data_body(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -137,6 +133,7 @@ def test_request_streamed_data_body(old_request, new_request): # passing in iterable def streaming_body(data): yield data # pragma: nocover + old_request.set_streamed_data_body(streaming_body("i will be streamed")) new_request.set_streamed_data_body(streaming_body("i will be streamed")) @@ -148,6 +145,7 @@ def streaming_body(data): assert isinstance(new_request.content, collections.Iterable) assert old_request.headers == new_request.headers == {} + def test_request_streamed_data_body_non_iterable(old_request, new_request): # should fail before nullifying the files property old_request.files = new_request.files = "foo" @@ -165,11 +163,13 @@ def test_request_streamed_data_body_non_iterable(old_request, new_request): assert old_request.files == "foo" assert old_request.headers == new_request.headers == {} + def test_request_streamed_data_body_and_content_kwarg(old_request): # passing stream bodies to set_streamed_data_body # and passing a stream body to the content kwarg of the new request should be the same def streaming_body(data): yield data # pragma: nocover + old_request.set_streamed_data_body(streaming_body("stream")) new_request = RestHttpRequest("GET", "/", content=streaming_body("stream")) assert old_request.files == new_request.files == None @@ -180,6 +180,7 @@ def streaming_body(data): assert isinstance(new_request.content, collections.Iterable) assert old_request.headers == new_request.headers == {} + def test_request_text_body(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -189,31 +190,33 @@ def test_request_text_body(old_request, new_request): assert old_request.files == new_request.files == None assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - "i am text" + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == "i am text" ) - assert old_request.headers['Content-Length'] == new_request.headers['Content-Length'] == '9' + assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "9" assert not old_request.headers.get("Content-Type") assert new_request.headers["Content-Type"] == "text/plain" + def test_request_text_body_and_content_kwarg(old_request): old_request.set_text_body("i am text") new_request = RestHttpRequest("GET", "/", content="i am text") assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - "i am text" + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == "i am text" ) assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "9" assert old_request.files == new_request.files == None + def test_request_xml_body(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -224,29 +227,31 @@ def test_request_xml_body(old_request, new_request): assert old_request.files == new_request.files == None assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - b"\n" + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == b"\n" ) - assert old_request.headers == new_request.headers == {'Content-Length': '47'} + assert old_request.headers == new_request.headers == {"Content-Length": "47"} + def test_request_xml_body_and_content_kwarg(old_request): old_request.set_text_body("i am text") new_request = RestHttpRequest("GET", "/", content="i am text") assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - "i am text" + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == "i am text" ) assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "9" assert old_request.files == new_request.files == None + def test_request_json_body(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -257,34 +262,36 @@ def test_request_json_body(old_request, new_request): assert old_request.files == new_request.files == None assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - json.dumps(json_input) + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == json.dumps(json_input) ) - assert old_request.headers["Content-Length"] == new_request.headers['Content-Length'] == '18' + assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "18" assert not old_request.headers.get("Content-Type") assert new_request.headers["Content-Type"] == "application/json" + def test_request_json_body_and_json_kwarg(old_request): json_input = {"hello": "world"} old_request.set_json_body(json_input) new_request = RestHttpRequest("GET", "/", json=json_input) assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - json.dumps(json_input) + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == json.dumps(json_input) ) - assert old_request.headers["Content-Length"] == new_request.headers['Content-Length'] == '18' + assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "18" assert not old_request.headers.get("Content-Type") assert new_request.headers["Content-Type"] == "application/json" assert old_request.files == new_request.files == None + def test_request_formdata_body_files(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -297,17 +304,13 @@ def test_request_formdata_body_files(old_request, new_request): new_request.set_formdata_body({"fileName": "hello.jpg"}) assert old_request.data == new_request.data == None - assert ( - old_request.files == - new_request.files == - new_request.content == - {'fileName': (None, 'hello.jpg')} - ) + assert old_request.files == new_request.files == new_request.content == {"fileName": (None, "hello.jpg")} # we don't set any multipart headers with boundaries # we rely on the transport to boundary calculating assert old_request.headers == new_request.headers == {} + def test_request_formdata_body_data(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -323,17 +326,18 @@ def test_request_formdata_body_data(old_request, new_request): assert old_request.files == new_request.files == None assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - {"fileName": "hello.jpg"} + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == {"fileName": "hello.jpg"} ) # old behavior would pop out the Content-Type header # new behavior doesn't do that assert old_request.headers == {} - assert new_request.headers == {'Content-Type': "application/x-www-form-urlencoded"} + assert new_request.headers == {"Content-Type": "application/x-www-form-urlencoded"} + def test_request_formdata_body_and_files_kwarg(old_request): files = {"fileName": "hello.jpg"} @@ -342,7 +346,8 @@ def test_request_formdata_body_and_files_kwarg(old_request): assert old_request.data == new_request.data == None assert old_request.body == new_request.body == None assert old_request.headers == new_request.headers == {} - assert old_request.files == new_request.files == {'fileName': (None, 'hello.jpg')} + assert old_request.files == new_request.files == {"fileName": (None, "hello.jpg")} + def test_request_formdata_body_and_data_kwarg(old_request): data = {"fileName": "hello.jpg"} @@ -352,17 +357,18 @@ def test_request_formdata_body_and_data_kwarg(old_request): old_request.set_formdata_body(data) new_request = RestHttpRequest("GET", "/", data=data) assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - {"fileName": "hello.jpg"} + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == {"fileName": "hello.jpg"} ) assert old_request.headers == {} assert new_request.headers == {"Content-Type": "application/x-www-form-urlencoded"} assert old_request.files == new_request.files == None + def test_request_bytes_body(old_request, new_request): assert old_request.files == new_request.files == None assert old_request.data == new_request.data == None @@ -373,26 +379,27 @@ def test_request_bytes_body(old_request, new_request): assert old_request.files == new_request.files == None assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - bytes_input + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == bytes_input ) - assert old_request.headers == new_request.headers == {'Content-Length': '13'} + assert old_request.headers == new_request.headers == {"Content-Length": "13"} + def test_request_bytes_body_and_content_kwarg(old_request): bytes_input = b"hello, world!" old_request.set_bytes_body(bytes_input) new_request = RestHttpRequest("GET", "/", content=bytes_input) assert ( - old_request.data == - new_request.data == - old_request.body == - new_request.body == - new_request.content == - bytes_input + old_request.data + == new_request.data + == old_request.body + == new_request.body + == new_request.content + == bytes_input ) - assert old_request.headers == new_request.headers == {'Content-Length': '13'} + assert old_request.headers == new_request.headers == {"Content-Length": "13"} assert old_request.files == new_request.files diff --git a/sdk/core/azure-core/tests/test_rest_response_backcompat.py b/sdk/core/azure-core/tests/test_rest_response_backcompat.py index 2c890cb84fd7..c3a7cf9bbdf3 100644 --- a/sdk/core/azure-core/tests/test_rest_response_backcompat.py +++ b/sdk/core/azure-core/tests/test_rest_response_backcompat.py @@ -12,28 +12,34 @@ from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import RequestsTransport + @pytest.fixture def old_request(port): return PipelineTransportHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + @pytest.fixture def old_response(old_request): return RequestsTransport().send(old_request) + @pytest.fixture def new_request(port): return RestHttpRequest("GET", "http://localhost:{}/streams/basic".format(port)) + @pytest.fixture def new_response(new_request): return RequestsTransport().send(new_request) + def test_response_attr_parity(old_response, new_response): for attr in dir(old_response): if not attr[0] == "_": # if not a private attr, we want parity assert hasattr(new_response, attr) + def test_response_set_attrs(old_response, new_response): for attr in dir(old_response): if attr[0] == "_": @@ -48,21 +54,29 @@ def test_response_set_attrs(old_response, new_response): setattr(new_response, attr, "foo") assert getattr(old_response, attr) == getattr(new_response, attr) == "foo" + def test_response_block_size(old_response, new_response): assert old_response.block_size == new_response.block_size == 4096 old_response.block_size = 500 new_response.block_size = 500 assert old_response.block_size == new_response.block_size == 500 + def test_response_body(old_response, new_response): assert old_response.body() == new_response.body() == b"Hello, world!" + def test_response_internal_response(old_response, new_response, port): - assert old_response.internal_response.url == new_response.internal_response.url == "http://localhost:{}/streams/basic".format(port) + assert ( + old_response.internal_response.url + == new_response.internal_response.url + == "http://localhost:{}/streams/basic".format(port) + ) old_response.internal_response = "foo" new_response.internal_response = "foo" assert old_response.internal_response == new_response.internal_response == "foo" + def test_response_stream_download(old_request, new_request): transport = RequestsTransport() pipeline = Pipeline(transport) @@ -74,44 +88,53 @@ def test_response_stream_download(old_request, new_request): new_string = b"".join(new_response.stream_download(pipeline)) assert old_string == new_string == b"Hello, world!" + def test_response_request(old_response, new_response, port): assert old_response.request.url == new_response.request.url == "http://localhost:{}/streams/basic".format(port) old_response.request = "foo" new_response.request = "foo" assert old_response.request == new_response.request == "foo" + def test_response_status_code(old_response, new_response): assert old_response.status_code == new_response.status_code == 200 old_response.status_code = 202 new_response.status_code = 202 assert old_response.status_code == new_response.status_code == 202 + def test_response_headers(old_response, new_response): - assert set(old_response.headers.keys()) == set(new_response.headers.keys()) == set(["Content-Type", "Connection", "Server", "Date"]) + assert ( + set(old_response.headers.keys()) + == set(new_response.headers.keys()) + == set(["Content-Type", "Connection", "Server", "Date"]) + ) old_response.headers = {"Hello": "world!"} new_response.headers = {"Hello": "world!"} assert old_response.headers == new_response.headers == {"Hello": "world!"} + def test_response_reason(old_response, new_response): assert old_response.reason == new_response.reason == "OK" old_response.reason = "Not OK" new_response.reason = "Not OK" assert old_response.reason == new_response.reason == "Not OK" + def test_response_content_type(old_response, new_response): assert old_response.content_type == new_response.content_type == "text/html; charset=utf-8" old_response.content_type = "application/json" new_response.content_type = "application/json" assert old_response.content_type == new_response.content_type == "application/json" + def _create_multiapart_request(http_request_class): class ResponsePolicy(object): - def on_request(self, *args): return def on_response(self, request, response): - response.http_response.headers['x-ms-fun'] = 'true' + response.http_response.headers["x-ms-fun"] = "true" req0 = http_request_class("DELETE", "/container0/blob0") req1 = http_request_class("DELETE", "/container1/blob1") @@ -119,6 +142,7 @@ def on_response(self, request, response): request.set_multipart_mixed(req0, req1, policies=[ResponsePolicy()]) return request + def _test_parts(response): # hack the content type parts = response.parts() @@ -126,11 +150,11 @@ def _test_parts(response): parts0 = parts[0] assert parts0.status_code == 202 - assert parts0.headers['x-ms-fun'] == 'true' + assert parts0.headers["x-ms-fun"] == "true" parts1 = parts[1] assert parts1.status_code == 404 - assert parts1.headers['x-ms-fun'] == 'true' + assert parts1.headers["x-ms-fun"] == "true" def test_response_parts(port): diff --git a/sdk/core/azure-core/tests/test_rest_stream_responses.py b/sdk/core/azure-core/tests/test_rest_stream_responses.py index 645e86232ca2..7b37c9ee797a 100644 --- a/sdk/core/azure-core/tests/test_rest_stream_responses.py +++ b/sdk/core/azure-core/tests/test_rest_stream_responses.py @@ -9,19 +9,17 @@ from azure.core.exceptions import StreamClosedError, StreamConsumedError, ResponseNotReadError from azure.core.exceptions import HttpResponseError, ServiceRequestError + def _assert_stream_state(response, open): # if open is true, check the stream is open. # if false, check if everything is closed - checks = [ - response._internal_response._content_consumed, - response.is_closed, - response.is_stream_consumed - ] + checks = [response._internal_response._content_consumed, response.is_closed, response.is_stream_consumed] if open: assert not any(checks) else: assert all(checks) + def test_iter_raw(client): request = HttpRequest("GET", "/streams/basic") with client.send_request(request, stream=True) as response: @@ -36,6 +34,7 @@ def test_iter_raw(client): assert response.is_closed assert response.is_stream_consumed + def test_iter_raw_on_iterable(client): request = HttpRequest("GET", "/streams/iterable") @@ -45,6 +44,7 @@ def test_iter_raw_on_iterable(client): raw += part assert raw == b"Hello, world!" + def test_iter_with_error(client): request = HttpRequest("GET", "/errors/403") @@ -64,6 +64,7 @@ def test_iter_with_error(client): raise ValueError("Should error before entering") assert response.is_closed + def test_iter_bytes(client): request = HttpRequest("GET", "/streams/basic") @@ -79,6 +80,7 @@ def test_iter_bytes(client): assert response.is_stream_consumed assert raw == b"Hello, world!" + @pytest.mark.skip(reason="We've gotten rid of iter_text for now") def test_iter_text(client): request = HttpRequest("GET", "/basic/string") @@ -88,6 +90,7 @@ def test_iter_text(client): content += part assert content == "Hello, world!" + @pytest.mark.skip(reason="We've gotten rid of iter_lines for now") def test_iter_lines(client): request = HttpRequest("GET", "/basic/lines") @@ -98,6 +101,7 @@ def test_iter_lines(client): content.append(line) assert content == ["Hello,\n", "world!"] + def test_sync_streaming_response(client): request = HttpRequest("GET", "/streams/basic") @@ -111,6 +115,7 @@ def test_sync_streaming_response(client): assert response.content == b"Hello, world!" assert response.is_closed + def test_cannot_read_after_stream_consumed(client, port): request = HttpRequest("GET", "/streams/basic") @@ -127,6 +132,7 @@ def test_cannot_read_after_stream_consumed(client, port): assert "".format(port) in str(ex.value) assert "You have likely already consumed this stream, so it can not be accessed anymore" in str(ex.value) + def test_cannot_read_after_response_closed(port, client): request = HttpRequest("GET", "/streams/basic") @@ -138,6 +144,7 @@ def test_cannot_read_after_response_closed(port, client): assert "".format(port) in str(ex.value) assert "can no longer be read or streamed, since the response has already been closed" in str(ex.value) + def test_decompress_plain_no_header(client): # thanks to Xiang Yan for this test! account_name = "coretests" @@ -149,6 +156,7 @@ def test_decompress_plain_no_header(client): response.read() assert response.content == b"test" + def test_compress_plain_no_header(client): # thanks to Xiang Yan for this test! account_name = "coretests" @@ -159,6 +167,7 @@ def test_compress_plain_no_header(client): data = b"".join(list(iter)) assert data == b"test" + def test_decompress_compressed_no_header(client): # thanks to Xiang Yan for this test! account_name = "coretests" @@ -167,7 +176,8 @@ def test_decompress_compressed_no_header(client): response = client.send_request(request, stream=True) iter = response.iter_bytes() data = b"".join(list(iter)) - assert data == b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\n+I-.\x01\x00\x0c~\x7f\xd8\x04\x00\x00\x00' + assert data == b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\n+I-.\x01\x00\x0c~\x7f\xd8\x04\x00\x00\x00" + def test_decompress_compressed_header(client): # thanks to Xiang Yan for this test! @@ -180,6 +190,7 @@ def test_decompress_compressed_header(client): data = b"".join(list(iter)) assert data == b"test" + def test_iter_read(client): # thanks to McCoy Patiño for this test! request = HttpRequest("GET", "/basic/string") @@ -190,6 +201,7 @@ def test_iter_read(client): assert part assert response.text() + def test_iter_read_back_and_forth(client): # thanks to McCoy Patiño for this test! @@ -209,13 +221,15 @@ def test_iter_read_back_and_forth(client): with pytest.raises(ResponseNotReadError): response.text() + def test_stream_with_return_pipeline_response(client): request = HttpRequest("GET", "/basic/string") pipeline_response = client.send_request(request, stream=True, _return_pipeline_response=True) assert hasattr(pipeline_response, "http_request") assert hasattr(pipeline_response, "http_response") assert hasattr(pipeline_response, "context") - assert list(pipeline_response.http_response.iter_bytes()) == [b'Hello, world!'] + assert list(pipeline_response.http_response.iter_bytes()) == [b"Hello, world!"] + def test_error_reading(client): request = HttpRequest("GET", "/errors/403") @@ -230,18 +244,21 @@ def test_error_reading(client): assert response.content == b"" # try giving a really slow response, see what happens + def test_pass_kwarg_to_iter_bytes(client): request = HttpRequest("GET", "/basic/string") response = client.send_request(request, stream=True) for part in response.iter_bytes(chunk_size=5): assert part + def test_pass_kwarg_to_iter_raw(client): request = HttpRequest("GET", "/basic/string") response = client.send_request(request, stream=True) for part in response.iter_raw(chunk_size=5): assert part + def test_decompress_compressed_header(client): # expect plain text request = HttpRequest("GET", "/encoding/gzip") @@ -251,6 +268,7 @@ def test_decompress_compressed_header(client): assert response.content == content assert response.text() == "hello world" + def test_deflate_decompress_compressed_header(client): # expect plain text request = HttpRequest("GET", "/encoding/deflate") @@ -260,6 +278,7 @@ def test_deflate_decompress_compressed_header(client): assert response.content == content assert response.text() == "hi there" + def test_decompress_compressed_header_stream(client): # expect plain text request = HttpRequest("GET", "/encoding/gzip") @@ -269,6 +288,7 @@ def test_decompress_compressed_header_stream(client): assert response.content == content assert response.text() == "hello world" + def test_decompress_compressed_header_stream_body_content(client): # expect plain text request = HttpRequest("GET", "/encoding/gzip") diff --git a/sdk/core/azure-core/tests/test_retry_policy.py b/sdk/core/azure-core/tests/test_retry_policy.py index 994841b52d22..c5a018014b28 100644 --- a/sdk/core/azure-core/tests/test_retry_policy.py +++ b/sdk/core/azure-core/tests/test_retry_policy.py @@ -43,13 +43,10 @@ def test_retry_code_class_variables(): assert 429 in retry_policy._RETRY_CODES assert 501 not in retry_policy._RETRY_CODES + def test_retry_types(): history = ["1", "2", "3"] - settings = { - 'history': history, - 'backoff': 1, - 'max_backoff': 10 - } + settings = {"history": history, "backoff": 1, "max_backoff": 10} retry_policy = RetryPolicy() backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 @@ -62,7 +59,10 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input,http_request,http_response", product(['0', '800', '1000', '1200'], HTTP_REQUESTS, HTTP_RESPONSES)) + +@pytest.mark.parametrize( + "retry_after_input,http_request,http_response", product(["0", "800", "1000", "1200"], HTTP_REQUESTS, HTTP_RESPONSES) +) def test_retry_after(retry_after_input, http_request, http_response): retry_policy = RetryPolicy() request = http_request("GET", "http://localhost") @@ -71,7 +71,7 @@ def test_retry_after(retry_after_input, http_request, http_response): pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) seconds = float(retry_after_input) - assert retry_after == seconds/1000.0 + assert retry_after == seconds / 1000.0 response.headers.pop("retry-after-ms") response.headers["Retry-After"] = retry_after_input retry_after = retry_policy.get_retry_after(pipeline_response) @@ -80,7 +80,10 @@ def test_retry_after(retry_after_input, http_request, http_response): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input,http_request,http_response", product(['0', '800', '1000', '1200'], HTTP_REQUESTS, HTTP_RESPONSES)) + +@pytest.mark.parametrize( + "retry_after_input,http_request,http_response", product(["0", "800", "1000", "1200"], HTTP_REQUESTS, HTTP_RESPONSES) +) def test_x_ms_retry_after(retry_after_input, http_request, http_response): retry_policy = RetryPolicy() request = http_request("GET", "http://localhost") @@ -89,7 +92,7 @@ def test_x_ms_retry_after(retry_after_input, http_request, http_response): pipeline_response = PipelineResponse(request, response, None) retry_after = retry_policy.get_retry_after(pipeline_response) seconds = float(retry_after_input) - assert retry_after == seconds/1000.0 + assert retry_after == seconds / 1000.0 response.headers.pop("x-ms-retry-after-ms") response.headers["Retry-After"] = retry_after_input retry_after = retry_policy.get_retry_after(pipeline_response) @@ -98,15 +101,19 @@ def test_x_ms_retry_after(retry_after_input, http_request, http_response): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_retry_on_429(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 + def __exit__(self, exc_type, exc_val, exc_tb): pass + def close(self): pass + def open(self): pass @@ -116,22 +123,26 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe response.status_code = 429 return response - http_request = http_request('GET', 'http://localhost/') - http_retry = RetryPolicy(retry_total = 1) + http_request = http_request("GET", "http://localhost/") + http_retry = RetryPolicy(retry_total=1) transport = MockTransport() pipeline = Pipeline(transport, [http_retry]) pipeline.run(http_request) assert transport._count == 2 + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_no_retry_on_201(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._count = 0 + def __exit__(self, exc_type, exc_val, exc_tb): pass + def close(self): pass + def open(self): pass @@ -143,30 +154,34 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe response.headers = headers return response - http_request = http_request('GET', 'http://localhost/') - http_retry = RetryPolicy(retry_total = 1) + http_request = http_request("GET", "http://localhost/") + http_retry = RetryPolicy(retry_total=1) transport = MockTransport() pipeline = Pipeline(transport, [http_retry]) pipeline.run(http_request) assert transport._count == 1 + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_retry_seekable_stream(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._first = True + def __exit__(self, exc_type, exc_val, exc_tb): pass + def close(self): pass + def open(self): pass def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineResponse if self._first: self._first = False - request.body.seek(0,2) - raise AzureError('fail on first') + request.body.seek(0, 2) + raise AzureError("fail on first") position = request.body.tell() assert position == 0 response = create_http_response(http_response, request, None) @@ -174,21 +189,25 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe return response data = BytesIO(b"Lots of dataaaa") - http_request = http_request('GET', 'http://localhost/') + http_request = http_request("GET", "http://localhost/") http_request.set_streamed_data_body(data) - http_retry = RetryPolicy(retry_total = 1) + http_retry = RetryPolicy(retry_total=1) pipeline = Pipeline(MockTransport(), [http_retry]) pipeline.run(http_request) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_retry_seekable_file(http_request, http_response): class MockTransport(HttpTransport): def __init__(self): self._first = True + def __exit__(self, exc_type, exc_val, exc_tb): pass + def close(self): pass + def open(self): pass @@ -197,12 +216,12 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe self._first = False for value in request.files.values(): name, body = value[0], value[1] - if name and body and hasattr(body, 'read'): - body.seek(0,2) - raise AzureError('fail on first') + if name and body and hasattr(body, "read"): + body.seek(0, 2) + raise AzureError("fail on first") for value in request.files.values(): name, body = value[0], value[1] - if name and body and hasattr(body, 'read'): + if name and body and hasattr(body, "read"): position = body.tell() assert not position response = create_http_response(http_response, request, None) @@ -210,15 +229,15 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe return response file = tempfile.NamedTemporaryFile(delete=False) - file.write(b'Lots of dataaaa') + file.write(b"Lots of dataaaa") file.close() - http_request = http_request('GET', 'http://localhost/') - headers = {'Content-Type': "multipart/form-data"} + http_request = http_request("GET", "http://localhost/") + headers = {"Content-Type": "multipart/form-data"} http_request.headers = headers - with open(file.name, 'rb') as f: + with open(file.name, "rb") as f: form_data_content = { - 'fileContent': f, - 'fileName': f.name, + "fileContent": f, + "fileName": f.name, } http_request.set_formdata_body(form_data_content) http_retry = RetryPolicy(retry_total=1) @@ -226,6 +245,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe pipeline.run(http_request) os.unlink(f.name) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_retry_timeout(http_request): timeout = 1 @@ -245,6 +265,7 @@ def send(request, **kwargs): with pytest.raises(ServiceResponseTimeoutError): response = pipeline.run(http_request("GET", "http://localhost/")) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_timeout_defaults(http_request, http_response): """When "timeout" is not set, the policy should not override the transport's timeout configuration""" @@ -266,8 +287,10 @@ def send(request, **kwargs): pipeline.run(http_request("GET", "http://localhost/")) assert transport.send.call_count == 1, "policy should not retry: its first send succeeded" + combinations = [(ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)] + @pytest.mark.parametrize( "combinations,http_request", product(combinations, HTTP_REQUESTS), @@ -275,7 +298,7 @@ def send(request, **kwargs): def test_does_not_sleep_after_timeout(combinations, http_request): # With default settings policy will sleep twice before exhausting its retries: 1.6s, 3.2s. # It should not sleep the second time when given timeout=1 - transport_error,expected_timeout_error = combinations + transport_error, expected_timeout_error = combinations timeout = 1 transport = Mock( diff --git a/sdk/core/azure-core/tests/test_serialization.py b/sdk/core/azure-core/tests/test_serialization.py index 01630132b58c..e76d9e80a86b 100644 --- a/sdk/core/azure-core/tests/test_serialization.py +++ b/sdk/core/azure-core/tests/test_serialization.py @@ -48,26 +48,34 @@ def to_dict(self): class NegativeUtcOffset(tzinfo): """tzinfo class with UTC offset of -12 hours""" + _offset = timedelta(seconds=-43200) _dst = timedelta(0) _name = "-1200" + def utcoffset(self, dt): return self.__class__._offset + def dst(self, dt): return self.__class__._dst + def tzname(self, dt): return self.__class__._name class PositiveUtcOffset(tzinfo): """tzinfo class with UTC offset of +12 hours""" + _offset = timedelta(seconds=43200) _dst = timedelta(0) _name = "+1200" + def utcoffset(self, dt): return self.__class__._offset + def dst(self, dt): return self.__class__._dst + def tzname(self, dt): return self.__class__._name @@ -77,32 +85,39 @@ def test_NULL_is_falsy(): assert bool(NULL) is False assert NULL is NULL + @pytest.fixture def json_dumps_with_encoder(): def func(obj): return json.dumps(obj, cls=AzureJSONEncoder) + return func + def test_bytes(json_dumps_with_encoder): test_bytes = b"mybytes" result = json.loads(json_dumps_with_encoder(test_bytes)) assert base64.b64decode(result) == test_bytes - + + def test_byte_array_ascii(json_dumps_with_encoder): test_byte_array = bytearray("mybytes", "ascii") result = json.loads(json_dumps_with_encoder(test_byte_array)) assert base64.b64decode(result) == test_byte_array + def test_byte_array_utf8(json_dumps_with_encoder): test_byte_array = bytearray("mybytes", "utf-8") result = json.loads(json_dumps_with_encoder(test_byte_array)) assert base64.b64decode(result) == test_byte_array + def test_byte_array_utf16(json_dumps_with_encoder): test_byte_array = bytearray("mybytes", "utf-16") result = json.loads(json_dumps_with_encoder(test_byte_array)) assert base64.b64decode(result) == test_byte_array + def test_dictionary_basic(json_dumps_with_encoder): test_obj = { "string": "myid", @@ -115,6 +130,7 @@ def test_dictionary_basic(json_dumps_with_encoder): assert json.dumps(test_obj) == complex_serialized assert json.loads(complex_serialized) == test_obj + def test_model_basic(json_dumps_with_encoder): class BasicModel(SerializerMixin): def __init__(self): @@ -126,7 +142,7 @@ def __init__(self): self.bytes_data = b"data as bytes" expected = BasicModel() - expected_bytes = "data as bytes" if sys.version_info.major == 2 else "ZGF0YSBhcyBieXRlcw==" # cspell:disable-line + expected_bytes = "data as bytes" if sys.version_info.major == 2 else "ZGF0YSBhcyBieXRlcw==" # cspell:disable-line expected_dict = { "string": "myid", "number": 42, @@ -137,46 +153,49 @@ def __init__(self): } assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_dictionary_datetime(json_dumps_with_encoder): test_obj = { "timedelta": timedelta(1), "date": date(2021, 5, 12), - "datetime": datetime.strptime('2012-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ"), - "time": time(11,12,13), + "datetime": datetime.strptime("2012-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ"), + "time": time(11, 12, 13), } expected = { "timedelta": "P1DT00H00M00S", "date": "2021-05-12", - "datetime": '2012-02-24T00:53:52.780000Z', - 'time': '11:12:13', + "datetime": "2012-02-24T00:53:52.780000Z", + "time": "11:12:13", } assert json.loads(json_dumps_with_encoder(test_obj)) == expected + def test_model_datetime(json_dumps_with_encoder): class DatetimeModel(SerializerMixin): def __init__(self): self.timedelta = timedelta(1) self.date = date(2021, 5, 12) - self.datetime = datetime.strptime('2012-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") - self.time = time(11,12,13) + self.datetime = datetime.strptime("2012-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") + self.time = time(11, 12, 13) expected = DatetimeModel() expected_dict = { "timedelta": "P1DT00H00M00S", "date": "2021-05-12", - "datetime": '2012-02-24T00:53:52.780000Z', - 'time': '11:12:13', + "datetime": "2012-02-24T00:53:52.780000Z", + "time": "11:12:13", } assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_model_key_vault(json_dumps_with_encoder): class Attributes(SerializerMixin): def __init__(self): self.enabled = True - self.not_before = datetime.strptime('2012-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") - self.expires = datetime.strptime('2032-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") - self.created = datetime.strptime('2020-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") - self.updated = datetime.strptime('2021-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") + self.not_before = datetime.strptime("2012-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") + self.expires = datetime.strptime("2032-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") + self.created = datetime.strptime("2020-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") + self.updated = datetime.strptime("2021-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") class ResourceId(SerializerMixin): def __init__(self): @@ -198,7 +217,9 @@ def __init__(self): self._tags = None expected = Properties() - expected_bytes = "thumbprint bytes" if sys.version_info.major == 2 else "dGh1bWJwcmludCBieXRlcw==" # cspell:disable-line + expected_bytes = ( + "thumbprint bytes" if sys.version_info.major == 2 else "dGh1bWJwcmludCBieXRlcw==" + ) # cspell:disable-line expected_dict = { "_attributes": { "enabled": True, @@ -221,14 +242,15 @@ def __init__(self): } assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_serialize_datetime(json_dumps_with_encoder): - date_obj = datetime.strptime('2015-01-01T00:00:00', "%Y-%m-%dT%H:%M:%S") + date_obj = datetime.strptime("2015-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S") date_str = json_dumps_with_encoder(date_obj) assert date_str == '"2015-01-01T00:00:00Z"' - date_obj = datetime.strptime('1999-12-31T23:59:59', "%Y-%m-%dT%H:%M:%S").replace(tzinfo=NegativeUtcOffset()) + date_obj = datetime.strptime("1999-12-31T23:59:59", "%Y-%m-%dT%H:%M:%S").replace(tzinfo=NegativeUtcOffset()) date_str = json_dumps_with_encoder(date_obj) assert date_str == '"2000-01-01T11:59:59Z"' @@ -246,44 +268,49 @@ def test_serialize_datetime(json_dumps_with_encoder): date_str = json_dumps_with_encoder(date_obj) assert date_str == '"9999-12-31T23:59:59.999999Z"' - date_obj = datetime.strptime('2012-02-24T00:53:52.000001Z', "%Y-%m-%dT%H:%M:%S.%fZ") + date_obj = datetime.strptime("2012-02-24T00:53:52.000001Z", "%Y-%m-%dT%H:%M:%S.%fZ") date_str = json_dumps_with_encoder(date_obj) assert date_str == '"2012-02-24T00:53:52.000001Z"' - date_obj = datetime.strptime('2012-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") + date_obj = datetime.strptime("2012-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") date_str = json_dumps_with_encoder(date_obj) assert date_str == '"2012-02-24T00:53:52.780000Z"' + def test_serialize_datetime_subclass(json_dumps_with_encoder): - date_obj = DatetimeSubclass.strptime('2012-02-24T00:53:52.780Z', "%Y-%m-%dT%H:%M:%S.%fZ") + date_obj = DatetimeSubclass.strptime("2012-02-24T00:53:52.780Z", "%Y-%m-%dT%H:%M:%S.%fZ") date_str = json_dumps_with_encoder(date_obj) assert date_str == '"2012-02-24T00:53:52.780000Z"' + def test_serialize_time(json_dumps_with_encoder): - time_str = json_dumps_with_encoder(time(11,22,33)) + time_str = json_dumps_with_encoder(time(11, 22, 33)) assert time_str == '"11:22:33"' - time_str = json_dumps_with_encoder(time(11,22,33,444444)) + time_str = json_dumps_with_encoder(time(11, 22, 33, 444444)) assert time_str == '"11:22:33.444444"' + class BasicEnum(Enum): val = "Basic" + class StringEnum(str, Enum): val = "string" + class IntEnum(int, Enum): val = 1 + class FloatEnum(float, Enum): val = 1.5 + def test_dictionary_enum(json_dumps_with_encoder): - test_obj = { - "basic": BasicEnum.val - } + test_obj = {"basic": BasicEnum.val} with pytest.raises(TypeError): json_dumps_with_encoder(test_obj) @@ -291,18 +318,14 @@ def test_dictionary_enum(json_dumps_with_encoder): "basic": BasicEnum.val.value, "string": StringEnum.val.value, "int": IntEnum.val.value, - "float": FloatEnum.val.value - } - expected = { - "basic": "Basic", - "string": "string", - "int": 1, - "float": 1.5 + "float": FloatEnum.val.value, } + expected = {"basic": "Basic", "string": "string", "int": 1, "float": 1.5} serialized = json_dumps_with_encoder(test_obj) assert json.dumps(test_obj) == serialized assert json.loads(serialized) == expected + def test_model_enum(json_dumps_with_encoder): class BasicEnumModel: def __init__(self): @@ -319,21 +342,16 @@ def __init__(self): self.float = FloatEnum.val expected = EnumModel() - expected_dict = { - "basic": "Basic", - "string": "string", - "int": 1, - "float": 1.5 - } + expected_dict = {"basic": "Basic", "string": "string", "int": 1, "float": 1.5} assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_dictionary_none(json_dumps_with_encoder): assert json_dumps_with_encoder(None) == json.dumps(None) - test_obj = { - "entry": None - } + test_obj = {"entry": None} assert json.loads(json_dumps_with_encoder(test_obj)) == test_obj + def test_model_none(json_dumps_with_encoder): class NoneModel(SerializerMixin): def __init__(self): @@ -343,6 +361,7 @@ def __init__(self): expected_dict = {"entry": None} assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_dictionary_empty_collections(json_dumps_with_encoder): test_obj = { "dictionary": {}, @@ -352,6 +371,7 @@ def test_dictionary_empty_collections(json_dumps_with_encoder): assert json.dumps(test_obj) == json_dumps_with_encoder(test_obj) assert json.loads(json_dumps_with_encoder(test_obj)) == test_obj + def test_model_empty_collections(json_dumps_with_encoder): class EmptyCollectionsModel(SerializerMixin): def __init__(self): @@ -365,6 +385,7 @@ def __init__(self): } assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_model_inheritance(json_dumps_with_encoder): class ParentModel(SerializerMixin): def __init__(self): @@ -382,6 +403,7 @@ def __init__(self): } assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict + def test_model_recursion(json_dumps_with_encoder): class RecursiveModel(SerializerMixin): def __init__(self): @@ -404,7 +426,7 @@ def __init__(self): "list_of_me": None, "dict_of_me": None, "dict_of_list_of_me": None, - "list_of_dict_of_me": None + "list_of_dict_of_me": None, } ], "dict_of_me": { @@ -413,7 +435,7 @@ def __init__(self): "list_of_me": None, "dict_of_me": None, "dict_of_list_of_me": None, - "list_of_dict_of_me": None + "list_of_dict_of_me": None, } }, "dict_of_list_of_me": { @@ -423,19 +445,20 @@ def __init__(self): "list_of_me": None, "dict_of_me": None, "dict_of_list_of_me": None, - "list_of_dict_of_me": None + "list_of_dict_of_me": None, } ] }, "list_of_dict_of_me": [ - {"me": { + { + "me": { "name": "it's me!", "list_of_me": None, "dict_of_me": None, "dict_of_list_of_me": None, - "list_of_dict_of_me": None + "list_of_dict_of_me": None, } } - ] + ], } assert json.loads(json_dumps_with_encoder(expected.to_dict())) == expected_dict diff --git a/sdk/core/azure-core/tests/test_settings.py b/sdk/core/azure-core/tests/test_settings.py index fd228cbcae08..368094f3bc35 100644 --- a/sdk/core/azure-core/tests/test_settings.py +++ b/sdk/core/azure-core/tests/test_settings.py @@ -204,16 +204,12 @@ def test_config(self): def test_defaults(self): val = m.settings.defaults # assert isinstance(val, tuple) - defaults = m.settings.config( - log_level=20, tracing_enabled=False, tracing_implementation=None - ) + defaults = m.settings.config(log_level=20, tracing_enabled=False, tracing_implementation=None) assert val.log_level == defaults.log_level assert val.tracing_enabled == defaults.tracing_enabled assert val.tracing_implementation == defaults.tracing_implementation os.environ["AZURE_LOG_LEVEL"] = "debug" - defaults = m.settings.config( - log_level=20, tracing_enabled=False, tracing_implementation=None - ) + defaults = m.settings.config(log_level=20, tracing_enabled=False, tracing_implementation=None) assert val.log_level == defaults.log_level assert val.tracing_enabled == defaults.tracing_enabled assert val.tracing_implementation == defaults.tracing_implementation diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index d572a154d936..eeacb47f077c 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -9,12 +9,20 @@ ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator + try: from unittest import mock except ImportError: import mock import pytest -from utils import HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, create_http_response, create_transport_response, request_and_responses_product +from utils import ( + HTTP_RESPONSES, + REQUESTS_TRANSPORT_RESPONSES, + create_http_response, + create_transport_response, + request_and_responses_product, +) + @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_connection_error_response(http_request, http_response): @@ -24,13 +32,15 @@ def __init__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass + def close(self): pass + def open(self): pass def send(self, request, **kwargs): - request = http_request('GET', 'http://localhost/') + request = http_request("GET", "http://localhost/") response = create_http_response(http_response, request, None) response.status_code = 200 return response @@ -50,22 +60,23 @@ def stream(self, chunk_size, decode_content=False): while True: yield b"test" - class MockInternalResponse(): + class MockInternalResponse: def __init__(self): self.raw = MockTransport() def close(self): pass - http_request = http_request('GET', 'http://localhost/') + http_request = http_request("GET", "http://localhost/") pipeline = Pipeline(MockTransport()) http_response = create_http_response(http_response, http_request, None) http_response.internal_response = MockInternalResponse() stream = StreamDownloadGenerator(pipeline, http_response, decompress=False) - with mock.patch('time.sleep', return_value=None): + with mock.patch("time.sleep", return_value=None): with pytest.raises(requests.exceptions.ConnectionError): stream.__next__() + @pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES) def test_response_streaming_error_behavior(http_response): # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 diff --git a/sdk/core/azure-core/tests/test_streaming.py b/sdk/core/azure-core/tests/test_streaming.py index 25a08e1bf3b2..3ce3099fedda 100644 --- a/sdk/core/azure-core/tests/test_streaming.py +++ b/sdk/core/azure-core/tests/test_streaming.py @@ -30,6 +30,7 @@ from azure.core.pipeline.transport import RequestsTransport from utils import HTTP_REQUESTS + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decompress_plain_no_header(http_request): # expect plain text @@ -42,9 +43,10 @@ def test_decompress_plain_no_header(http_request): response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) content = b"".join(list(data)) - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_compress_plain_no_header_offline(port, http_request): # cspell:disable-next-line @@ -56,9 +58,10 @@ def test_compress_plain_no_header_offline(port, http_request): response.raise_for_status() data = response.stream_download(sender, decompress=False) content = b"".join(list(data)) - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.live_test_only @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_compress_plain_no_header(http_request): @@ -72,9 +75,10 @@ def test_compress_plain_no_header(http_request): response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) content = b"".join(list(data)) - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decompress_compressed_no_header(http_request): # expect compressed text @@ -88,11 +92,12 @@ def test_decompress_compressed_no_header(http_request): data = response.stream_download(client._pipeline, decompress=True) content = b"".join(list(data)) try: - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert False except UnicodeDecodeError: pass + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_compress_compressed_no_header_offline(port, http_request): # expect compressed text @@ -103,7 +108,8 @@ def test_compress_compressed_no_header_offline(port, http_request): data = response.stream_download(client._pipeline, decompress=False) content = b"".join(list(data)) with pytest.raises(UnicodeDecodeError): - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") + @pytest.mark.live_test_only @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @@ -119,11 +125,12 @@ def test_compress_compressed_no_header(http_request): data = response.stream_download(client._pipeline, decompress=False) content = b"".join(list(data)) try: - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert False except UnicodeDecodeError: pass + @pytest.mark.live_test_only @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decompress_plain_header(http_request): @@ -139,6 +146,7 @@ def test_decompress_plain_header(http_request): with pytest.raises(DecodeError): list(data) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decompress_plain_header_offline(port, http_request): request = http_request(method="GET", url="http://localhost:{}/streams/compressed".format(port)) @@ -149,6 +157,7 @@ def test_decompress_plain_header_offline(port, http_request): with pytest.raises(DecodeError): list(data) + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_compress_plain_header(http_request): # expect plain text @@ -161,9 +170,10 @@ def test_compress_plain_header(http_request): response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) content = b"".join(list(data)) - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.live_test_only @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decompress_compressed_header(http_request): @@ -177,9 +187,10 @@ def test_decompress_compressed_header(http_request): response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) content = b"".join(list(data)) - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decompress_compressed_header_offline(port, http_request): client = PipelineClient("") @@ -189,9 +200,10 @@ def test_decompress_compressed_header_offline(port, http_request): response.raise_for_status() data = response.stream_download(sender, decompress=True) content = b"".join(list(data)) - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert decoded == "test" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_compress_compressed_header(http_request): # expect compressed text @@ -205,7 +217,7 @@ def test_compress_compressed_header(http_request): data = response.stream_download(client._pipeline, decompress=False) content = b"".join(list(data)) try: - decoded = content.decode('utf-8') + decoded = content.decode("utf-8") assert False except UnicodeDecodeError: pass diff --git a/sdk/core/azure-core/tests/test_testserver.py b/sdk/core/azure-core/tests/test_testserver.py index 544778e32a79..507c08ae1156 100644 --- a/sdk/core/azure-core/tests/test_testserver.py +++ b/sdk/core/azure-core/tests/test_testserver.py @@ -26,8 +26,10 @@ from azure.core.pipeline.transport import RequestsTransport from utils import HTTP_REQUESTS import pytest + """This file does a simple call to the testserver to make sure we can use the testserver""" + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_smoke(port, http_request): request = http_request(method="GET", url="http://localhost:{}/basic/string".format(port)) diff --git a/sdk/core/azure-core/tests/test_tracing_decorator.py b/sdk/core/azure-core/tests/test_tracing_decorator.py index 5cabf2f247a4..4ee7ee621d96 100644 --- a/sdk/core/azure-core/tests/test_tracing_decorator.py +++ b/sdk/core/azure-core/tests/test_tracing_decorator.py @@ -21,6 +21,7 @@ from tracing_common import FakeSpan from utils import HTTP_REQUESTS + @pytest.fixture(scope="module") def fake_span(): settings.tracing_implementation.set_value(FakeSpan) @@ -53,7 +54,7 @@ def make_request(self, numb_times, **kwargs): return None response = self.pipeline.run(self.request, **kwargs) self.get_foo(merge_span=True) - kwargs['merge_span'] = True + kwargs["merge_span"] = True self.make_request(numb_times - 1, **kwargs) return response @@ -74,7 +75,7 @@ def get_foo(self): def check_name_is_different(self): time.sleep(0.001) - @distributed_trace(tracing_attributes={'foo': 'bar'}) + @distributed_trace(tracing_attributes={"foo": "bar"}) def tracing_attr(self): time.sleep(0.001) @@ -100,7 +101,6 @@ def test_get_function_and_class_name(http_request): @pytest.mark.usefixtures("fake_span") class TestDecorator(object): - @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decorator_tracing_attr(self, http_request): with FakeSpan(name="parent") as parent: @@ -111,7 +111,7 @@ def test_decorator_tracing_attr(self, http_request): assert parent.children[0].name == "MockClient.__init__" assert parent.children[1].name == "MockClient.tracing_attr" assert parent.children[1].kind == SpanKind.INTERNAL - assert parent.children[1].attributes == {'foo': 'bar'} + assert parent.children[1].attributes == {"foo": "bar"} @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_decorator_has_different_name(self, http_request): @@ -188,8 +188,7 @@ def test_span_complicated(self, http_request): @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_span_with_exception(self, http_request): - """Assert that if an exception is raised, the next sibling method is actually a sibling span. - """ + """Assert that if an exception is raised, the next sibling method is actually a sibling span.""" with FakeSpan(name="parent") as parent: client = MockClient(http_request) try: @@ -202,5 +201,5 @@ def test_span_with_exception(self, http_request): assert parent.children[0].name == "MockClient.__init__" assert parent.children[1].name == "MockClient.raising_exception" # Exception should propagate status for Opencensus - assert parent.children[1].status == 'Something went horribly wrong here' + assert parent.children[1].status == "Something went horribly wrong here" assert parent.children[2].name == "MockClient.get_foo" diff --git a/sdk/core/azure-core/tests/test_tracing_policy.py b/sdk/core/azure-core/tests/test_tracing_policy.py index b65dd893687f..4012b028be1a 100644 --- a/sdk/core/azure-core/tests/test_tracing_policy.py +++ b/sdk/core/azure-core/tests/test_tracing_policy.py @@ -37,7 +37,7 @@ def test_distributed_tracing_policy_solo(http_request, http_response): response.status_code = 202 response.headers["x-ms-request-id"] = "some request id" - assert request.headers.get("traceparent") == '123456789' + assert request.headers.get("traceparent") == "123456789" policy.on_response(pipeline_request, PipelineResponse(request, response, PipelineContext(None))) time.sleep(0.001) @@ -75,9 +75,7 @@ def test_distributed_tracing_policy_attributes(http_request, http_response): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: - policy = DistributedTracingPolicy(tracing_attributes={ - 'myattr': 'myvalue' - }) + policy = DistributedTracingPolicy(tracing_attributes={"myattr": "myvalue"}) request = http_request("GET", "http://localhost/temp?query=query") @@ -133,7 +131,7 @@ def test_distributed_tracing_policy_badurl(caplog, http_request, http_response): def test_distributed_tracing_policy_with_user_agent(http_request, http_response): """Test policy working with user agent.""" settings.tracing_implementation.set_value(FakeSpan) - with mock.patch.dict('os.environ', {"AZURE_HTTP_USER_AGENT": "mytools"}): + with mock.patch.dict("os.environ", {"AZURE_HTTP_USER_AGENT": "mytools"}): with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() @@ -152,7 +150,7 @@ def test_distributed_tracing_policy_with_user_agent(http_request, http_response) response.headers["x-ms-request-id"] = "some request id" pipeline_response = PipelineResponse(request, response, PipelineContext(None)) - assert request.headers.get("traceparent") == '123456789' + assert request.headers.get("traceparent") == "123456789" policy.on_response(pipeline_request, pipeline_response) @@ -185,7 +183,7 @@ def test_distributed_tracing_policy_with_user_agent(http_request, http_response) assert network_span.attributes.get("x-ms-request-id") is None assert network_span.attributes.get("http.status_code") == 504 # Exception should propagate status for Opencensus - assert network_span.status == 'Transport trouble' + assert network_span.status == "Transport trouble" @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) @@ -214,7 +212,7 @@ def operation_namer(http_request): assert http_request is request return "operation level name" - pipeline_request.context.options['network_span_namer'] = operation_namer + pipeline_request.context.options["network_span_namer"] = operation_namer policy.on_request(pipeline_request) diff --git a/sdk/core/azure-core/tests/test_universal_pipeline.py b/sdk/core/azure-core/tests/test_universal_pipeline.py index 14df6fcf6eb1..dfc3477afe24 100644 --- a/sdk/core/azure-core/tests/test_universal_pipeline.py +++ b/sdk/core/azure-core/tests/test_universal_pipeline.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -23,9 +23,10 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import logging import pickle + try: from unittest import mock except ImportError: @@ -35,12 +36,7 @@ import pytest from azure.core.exceptions import DecodeError, AzureError -from azure.core.pipeline import ( - Pipeline, - PipelineResponse, - PipelineRequest, - PipelineContext -) +from azure.core.pipeline import Pipeline, PipelineResponse, PipelineRequest, PipelineContext from azure.core.pipeline.policies import ( NetworkTraceLoggingPolicy, @@ -49,20 +45,26 @@ RetryPolicy, HTTPPolicy, ) -from utils import HTTP_REQUESTS, create_http_request, HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, create_http_response, create_transport_response, request_and_responses_product +from utils import ( + HTTP_REQUESTS, + create_http_request, + HTTP_RESPONSES, + REQUESTS_TRANSPORT_RESPONSES, + create_http_response, + create_transport_response, + request_and_responses_product, +) from azure.core.pipeline._tools import is_rest + def test_pipeline_context(): - kwargs={ - 'stream':True, - 'cont_token':"bla" - } - context = PipelineContext('transport', **kwargs) - context['foo'] = 'bar' - context['xyz'] = '123' - context['deserialized_data'] = 'marvelous' - - assert context['foo'] == 'bar' + kwargs = {"stream": True, "cont_token": "bla"} + context = PipelineContext("transport", **kwargs) + context["foo"] = "bar" + context["xyz"] = "123" + context["deserialized_data"] = "marvelous" + + assert context["foo"] == "bar" assert context.options == kwargs with pytest.raises(TypeError): @@ -71,15 +73,15 @@ def test_pipeline_context(): with pytest.raises(TypeError): context.update({}) - assert context.pop('foo') == 'bar' - assert 'foo' not in context + assert context.pop("foo") == "bar" + assert "foo" not in context serialized = pickle.dumps(context) revived_context = pickle.loads(serialized) # nosec assert revived_context.options == kwargs assert revived_context.transport is None - assert 'deserialized_data' in revived_context + assert "deserialized_data" in revived_context assert len(revived_context) == 1 @@ -90,13 +92,14 @@ def __deepcopy__(self, memodict={}): raise ValueError() body = Non_deep_copyable() - request = create_http_request(http_request, 'GET', 'http://localhost/', {'user-agent': 'test_request_history'}) + request = create_http_request(http_request, "GET", "http://localhost/", {"user-agent": "test_request_history"}) request.body = body request_history = RequestHistory(request) assert request_history.http_request.headers == request.headers assert request_history.http_request.url == request.url assert request_history.http_request.method == request.method + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_request_history_type_error(http_request): class Non_deep_copyable(object): @@ -104,17 +107,18 @@ def __deepcopy__(self, memodict={}): raise TypeError() body = Non_deep_copyable() - request = create_http_request(http_request, 'GET', 'http://localhost/', {'user-agent': 'test_request_history'}) + request = create_http_request(http_request, "GET", "http://localhost/", {"user-agent": "test_request_history"}) request.body = body request_history = RequestHistory(request) assert request_history.http_request.headers == request.headers assert request_history.http_request.url == request.url assert request_history.http_request.method == request.method -@mock.patch('azure.core.pipeline.policies._universal._LOGGER') + +@mock.patch("azure.core.pipeline.policies._universal._LOGGER") @pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES)) def test_no_log(mock_http_logger, http_request, http_response): - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") request = PipelineRequest(universal_request, PipelineContext(None)) http_logger = NetworkTraceLoggingPolicy() response = PipelineResponse(request, create_http_response(http_response, universal_request, None), request.context) @@ -127,20 +131,20 @@ def test_no_log(mock_http_logger, http_request, http_response): mock_http_logger.reset_mock() # I can enable it per request - request.context.options['logging_enable'] = True + request.context.options["logging_enable"] = True http_logger.on_request(request) assert mock_http_logger.debug.call_count >= 1 mock_http_logger.reset_mock() - request.context.options['logging_enable'] = True + request.context.options["logging_enable"] = True http_logger.on_response(request, response) assert mock_http_logger.debug.call_count >= 1 mock_http_logger.reset_mock() # I can enable it per request (bool value should be honored) - request.context.options['logging_enable'] = False + request.context.options["logging_enable"] = False http_logger.on_request(request) mock_http_logger.debug.assert_not_called() - request.context.options['logging_enable'] = False + request.context.options["logging_enable"] = False http_logger.on_response(request, response) mock_http_logger.debug.assert_not_called() mock_http_logger.reset_mock() @@ -156,16 +160,16 @@ def test_no_log(mock_http_logger, http_request, http_response): # I can enable it globally and override it locally http_logger.enable_http_logger = True - request.context.options['logging_enable'] = False + request.context.options["logging_enable"] = False http_logger.on_request(request) mock_http_logger.debug.assert_not_called() - response.context['logging_enable'] = False + response.context["logging_enable"] = False http_logger.on_response(request, response) mock_http_logger.debug.assert_not_called() mock_http_logger.reset_mock() # Let's make this request a failure, retried twice - request.context.options['logging_enable'] = True + request.context.options["logging_enable"] = True http_logger.on_request(request) http_logger.on_response(request, response) @@ -178,26 +182,32 @@ def test_no_log(mock_http_logger, http_request, http_response): second_count = mock_http_logger.debug.call_count assert second_count == first_count * 2 + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_retry_without_http_response(http_request): class NaughtyPolicy(HTTPPolicy): def send(*args): - raise AzureError('boo') + raise AzureError("boo") policies = [RetryPolicy(), NaughtyPolicy()] pipeline = Pipeline(policies=policies, transport=None) with pytest.raises(AzureError): - pipeline.run(http_request('GET', url='https://foo.bar')) + pipeline.run(http_request("GET", url="https://foo.bar")) + -@pytest.mark.parametrize("http_request,http_response,requests_transport_response", request_and_responses_product(HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES)) +@pytest.mark.parametrize( + "http_request,http_response,requests_transport_response", + request_and_responses_product(HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES), +) def test_raw_deserializer(http_request, http_response, requests_transport_response): raw_deserializer = ContentDecodePolicy() context = PipelineContext(None, stream=False) - universal_request = http_request('GET', 'http://localhost/') + universal_request = http_request("GET", "http://localhost/") request = PipelineRequest(universal_request, context) def build_response(body, content_type=None): if is_rest(http_response): + class MockResponse(http_response): def __init__(self, body, content_type): super(MockResponse, self).__init__( @@ -220,6 +230,7 @@ def read(self): return self.content else: + class MockResponse(http_response): def __init__(self, body, content_type): super(MockResponse, self).__init__(None, None) @@ -242,10 +253,10 @@ def body(self): assert result.tag == "utf8groot" # The basic deserializer works with unicode XML - response = build_response(u''.encode('utf-8'), content_type="application/xml") + response = build_response(''.encode("utf-8"), content_type="application/xml") raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] - assert result.attrib["language"] == u"français" + assert result.attrib["language"] == "français" # Catch some weird situation where content_type is XML, but content is JSON response = build_response(b'{"ugly": true}', content_type="application/xml") @@ -254,12 +265,12 @@ def body(self): assert result["ugly"] is True # Be sure I catch the correct exception if it's neither XML nor JSON - response = build_response(b'gibberish', content_type="application/xml") + response = build_response(b"gibberish", content_type="application/xml") with pytest.raises(DecodeError) as err: raw_deserializer.on_response(request, response) assert err.value.response is response.http_response - response = build_response(b'{{gibberish}}', content_type="application/xml") + response = build_response(b"{{gibberish}}", content_type="application/xml") with pytest.raises(DecodeError) as err: raw_deserializer.on_response(request, response) assert err.value.response is response.http_response @@ -295,13 +306,13 @@ def body(self): assert result == "data" # Let text/plain let through - response = build_response(b'I am groot', content_type="text/plain") + response = build_response(b"I am groot", content_type="text/plain") raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] assert result == "I am groot" # Let text/plain let through + BOM - response = build_response(b'\xef\xbb\xbfI am groot', content_type="text/plain") + response = build_response(b"\xef\xbb\xbfI am groot", content_type="text/plain") raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] assert result == "I am groot" @@ -312,44 +323,52 @@ def body(self): req_response.headers["content-type"] = "application/json" req_response._content = b'{"success": true}' req_response._content_consumed = True - response = PipelineResponse(None, create_transport_response(requests_transport_response, None, req_response), PipelineContext(None, stream=False)) + response = PipelineResponse( + None, + create_transport_response(requests_transport_response, None, req_response), + PipelineContext(None, stream=False), + ) raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] assert result["success"] is True # I can enable it per request - request.context.options['response_encoding'] = 'utf-8' - response = build_response(b'\xc3\xa9', content_type="text/plain") + request.context.options["response_encoding"] = "utf-8" + response = build_response(b"\xc3\xa9", content_type="text/plain") raw_deserializer.on_request(request) raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] - assert result == u"é" + assert result == "é" assert response.context["response_encoding"] == "utf-8" - del request.context['response_encoding'] + del request.context["response_encoding"] # I can enable it globally raw_deserializer = ContentDecodePolicy(response_encoding="utf-8") - response = build_response(b'\xc3\xa9', content_type="text/plain") + response = build_response(b"\xc3\xa9", content_type="text/plain") raw_deserializer.on_request(request) raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] - assert result == u"é" + assert result == "é" assert response.context["response_encoding"] == "utf-8" - del request.context['response_encoding'] + del request.context["response_encoding"] # Per request is more important - request.context.options['response_encoding'] = 'utf-8-sig' - response = build_response(b'\xc3\xa9', content_type="text/plain") + request.context.options["response_encoding"] = "utf-8-sig" + response = build_response(b"\xc3\xa9", content_type="text/plain") raw_deserializer.on_request(request) raw_deserializer.on_response(request, response) result = response.context["deserialized_data"] - assert result == u"é" + assert result == "é" assert response.context["response_encoding"] == "utf-8-sig" - del request.context['response_encoding'] + del request.context["response_encoding"] + def test_json_merge_patch(): - assert ContentDecodePolicy.deserialize_from_text('{"hello": "world"}', mime_type="application/merge-patch+json") == {"hello": "world"} + assert ContentDecodePolicy.deserialize_from_text( + '{"hello": "world"}', mime_type="application/merge-patch+json" + ) == {"hello": "world"} + def test_json_regex(): assert not ContentDecodePolicy.JSON_REGEXP.match("text/plain") diff --git a/sdk/core/azure-core/tests/test_user_agent_policy.py b/sdk/core/azure-core/tests/test_user_agent_policy.py index afed968f149f..f4121d7eeffa 100644 --- a/sdk/core/azure-core/tests/test_user_agent_policy.py +++ b/sdk/core/azure-core/tests/test_user_agent_policy.py @@ -5,6 +5,7 @@ """Tests for the user agent policy.""" from azure.core.pipeline.policies import UserAgentPolicy from azure.core.pipeline import PipelineRequest, PipelineContext + try: from unittest import mock except ImportError: @@ -12,32 +13,33 @@ import pytest from utils import HTTP_REQUESTS + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_user_agent_policy(http_request): - user_agent = UserAgentPolicy(base_user_agent='foo') - assert user_agent._user_agent == 'foo' + user_agent = UserAgentPolicy(base_user_agent="foo") + assert user_agent._user_agent == "foo" - user_agent = UserAgentPolicy(sdk_moniker='foosdk/1.0.0') - assert user_agent._user_agent.startswith('azsdk-python-foosdk/1.0.0 Python') + user_agent = UserAgentPolicy(sdk_moniker="foosdk/1.0.0") + assert user_agent._user_agent.startswith("azsdk-python-foosdk/1.0.0 Python") - user_agent = UserAgentPolicy(base_user_agent='foo', user_agent='bar', user_agent_use_env=False) - assert user_agent._user_agent == 'bar foo' + user_agent = UserAgentPolicy(base_user_agent="foo", user_agent="bar", user_agent_use_env=False) + assert user_agent._user_agent == "bar foo" - request = http_request('GET', 'http://localhost/') + request = http_request("GET", "http://localhost/") pipeline_request = PipelineRequest(request, PipelineContext(None)) - pipeline_request.context.options['user_agent'] = 'xyz' + pipeline_request.context.options["user_agent"] = "xyz" user_agent.on_request(pipeline_request) - assert request.headers['User-Agent'] == 'xyz bar foo' + assert request.headers["User-Agent"] == "xyz bar foo" @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_user_agent_environ(http_request): - with mock.patch.dict('os.environ', {'AZURE_HTTP_USER_AGENT': "mytools"}): + with mock.patch.dict("os.environ", {"AZURE_HTTP_USER_AGENT": "mytools"}): policy = UserAgentPolicy(None) assert policy.user_agent.endswith("mytools") - request = http_request('GET', 'http://localhost/') + request = http_request("GET", "http://localhost/") policy.on_request(PipelineRequest(request, PipelineContext(None))) assert request.headers["user-agent"].endswith("mytools") diff --git a/sdk/core/azure-core/tests/test_utils.py b/sdk/core/azure-core/tests/test_utils.py index f7406e0f5488..55d2d485a10b 100644 --- a/sdk/core/azure-core/tests/test_utils.py +++ b/sdk/core/azure-core/tests/test_utils.py @@ -5,80 +5,94 @@ import pytest from azure.core.utils import case_insensitive_dict + @pytest.fixture() def accept_cases(): return ["accept", "Accept", "ACCEPT", "aCCePT"] + def test_case_insensitive_dict_basic(accept_cases): my_dict = case_insensitive_dict({"accept": "application/json"}) for accept_case in accept_cases: assert my_dict[accept_case] == "application/json" + def test_case_insensitive_dict_override(accept_cases): for accept_case in accept_cases: my_dict = case_insensitive_dict({accept_case: "should-not/be-me"}) my_dict["accept"] = "application/json" assert my_dict[accept_case] == my_dict["accept"] == "application/json" + def test_case_insensitive_dict_initialization(): - dict_response = { - "platformUpdateDomainCount": 5, - "platformFaultDomainCount": 3, - "virtualMachines": [] - } + dict_response = {"platformUpdateDomainCount": 5, "platformFaultDomainCount": 3, "virtualMachines": []} a = case_insensitive_dict(platformUpdateDomainCount=5, platformFaultDomainCount=3, virtualMachines=[]) - b = case_insensitive_dict(zip(['platformUpdateDomainCount', 'platformFaultDomainCount', 'virtualMachines'], [5, 3, []])) - c = case_insensitive_dict([('platformFaultDomainCount', 3), ('platformUpdateDomainCount', 5), ('virtualMachines', [])]) - d = case_insensitive_dict({'virtualMachines': [], 'platformFaultDomainCount': 3, 'platformUpdateDomainCount': 5}) - e = case_insensitive_dict({'platformFaultDomainCount': 3, 'virtualMachines': []}, platformUpdateDomainCount=5) + b = case_insensitive_dict( + zip(["platformUpdateDomainCount", "platformFaultDomainCount", "virtualMachines"], [5, 3, []]) + ) + c = case_insensitive_dict( + [("platformFaultDomainCount", 3), ("platformUpdateDomainCount", 5), ("virtualMachines", [])] + ) + d = case_insensitive_dict({"virtualMachines": [], "platformFaultDomainCount": 3, "platformUpdateDomainCount": 5}) + e = case_insensitive_dict({"platformFaultDomainCount": 3, "virtualMachines": []}, platformUpdateDomainCount=5) f = case_insensitive_dict(dict_response) g = case_insensitive_dict(**dict_response) assert a == b == c == d == e == f == g dicts = [a, b, c, d, e, f, g] for d in dicts: assert len(d) == 3 - assert d['platformUpdateDomainCount'] == d['platformupdatedomaincount'] == d['PLATFORMUPDATEDOMAINCOUNT'] == 5 - assert d['platformFaultDomainCount'] == d['platformfaultdomaincount'] == d['PLATFORMFAULTDOMAINCOUNT'] == 3 - assert d['virtualMachines'] == d['virtualmachines'] == d['VIRTUALMACHINES'] == [] + assert d["platformUpdateDomainCount"] == d["platformupdatedomaincount"] == d["PLATFORMUPDATEDOMAINCOUNT"] == 5 + assert d["platformFaultDomainCount"] == d["platformfaultdomaincount"] == d["PLATFORMFAULTDOMAINCOUNT"] == 3 + assert d["virtualMachines"] == d["virtualmachines"] == d["VIRTUALMACHINES"] == [] + def test_case_insensitive_dict_cant_compare(): my_dict = case_insensitive_dict({"accept": "application/json"}) assert my_dict != "accept" + def test_case_insensitive_dict_lowerkey_items(): my_dict = case_insensitive_dict({"accept": "application/json"}) - assert list(my_dict.lowerkey_items()) == [("accept","application/json")] + assert list(my_dict.lowerkey_items()) == [("accept", "application/json")] + -@pytest.mark.parametrize("other, expected", ( - ({"PLATFORMUPDATEDOMAINCOUNT": 5}, True), - ({}, False), - (None, False), -)) +@pytest.mark.parametrize( + "other, expected", + ( + ({"PLATFORMUPDATEDOMAINCOUNT": 5}, True), + ({}, False), + (None, False), + ), +) def test_case_insensitive_dict_equality(other, expected): my_dict = case_insensitive_dict({"platformUpdateDomainCount": 5}) result = my_dict == other assert result == expected + def test_case_insensitive_dict_keys(): keys = ["One", "TWO", "tHrEe", "four"] - my_dict = case_insensitive_dict({key:idx for idx, key in enumerate(keys,1)}) + my_dict = case_insensitive_dict({key: idx for idx, key in enumerate(keys, 1)}) dict_keys = list(my_dict.keys()) assert dict_keys == keys + def test_case_insensitive_copy(): keys = ["One", "TWO", "tHrEe", "four"] - my_dict = case_insensitive_dict({key:idx for idx, key in enumerate(keys, 1)}) + my_dict = case_insensitive_dict({key: idx for idx, key in enumerate(keys, 1)}) copy_my_dict = my_dict.copy() assert copy_my_dict is not my_dict assert copy_my_dict == my_dict + def test_case_insensitive_keys_present(accept_cases): my_dict = case_insensitive_dict({"accept": "application/json"}) for key in accept_cases: assert key in my_dict + def test_case_insensitive_keys_delete(accept_cases): my_dict = case_insensitive_dict({"accept": "application/json"}) @@ -87,9 +101,10 @@ def test_case_insensitive_keys_delete(accept_cases): assert key not in my_dict my_dict[key] = "application/json" + def test_case_iter(): keys = ["One", "TWO", "tHrEe", "four"] - my_dict = case_insensitive_dict({key:idx for idx, key in enumerate(keys, 1)}) - + my_dict = case_insensitive_dict({key: idx for idx, key in enumerate(keys, 1)}) + for key in my_dict: - assert key in keys \ No newline at end of file + assert key in keys diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py index f34ff35d3f26..ccf09e38dbcb 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/__init__.py @@ -29,9 +29,11 @@ app.register_blueprint(xml_api, url_prefix="/xml") app.register_blueprint(headers_api, url_prefix="/headers") -@app.route('/health', methods=['GET']) + +@app.route("/health", methods=["GET"]) def latin_1_charset_utf8(): return Response(status=200) + if __name__ == "__main__": app.run(debug=True) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py index 86117c112810..02ab084b83d2 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/basic.py @@ -12,61 +12,58 @@ ) from .helpers import jsonify, get_dict -basic_api = Blueprint('basic_api', __name__) +basic_api = Blueprint("basic_api", __name__) -@basic_api.route('/string', methods=['GET']) + +@basic_api.route("/string", methods=["GET"]) def string(): - return Response( - "Hello, world!", status=200, mimetype="text/plain" - ) + return Response("Hello, world!", status=200, mimetype="text/plain") -@basic_api.route('/lines', methods=['GET']) + +@basic_api.route("/lines", methods=["GET"]) def lines(): - return Response( - "Hello,\nworld!", status=200, mimetype="text/plain" - ) + return Response("Hello,\nworld!", status=200, mimetype="text/plain") + -@basic_api.route("/bytes", methods=['GET']) +@basic_api.route("/bytes", methods=["GET"]) def bytes(): - return Response( - "Hello, world!".encode(), status=200, mimetype="text/plain" - ) + return Response("Hello, world!".encode(), status=200, mimetype="text/plain") -@basic_api.route("/html", methods=['GET']) + +@basic_api.route("/html", methods=["GET"]) def html(): - return Response( - "Hello, world!", status=200, mimetype="text/html" - ) + return Response("Hello, world!", status=200, mimetype="text/html") -@basic_api.route("/json", methods=['GET']) + +@basic_api.route("/json", methods=["GET"]) def json(): - return Response( - '{"greeting": "hello", "recipient": "world"}', status=200, mimetype="application/json" - ) + return Response('{"greeting": "hello", "recipient": "world"}', status=200, mimetype="application/json") + -@basic_api.route("/complicated-json", methods=['POST']) +@basic_api.route("/complicated-json", methods=["POST"]) def complicated_json(): # thanks to Sean Kane for this test! - assert request.json['EmptyByte'] == '' - assert request.json['EmptyUnicode'] == '' - assert request.json['SpacesOnlyByte'] == ' ' - assert request.json['SpacesOnlyUnicode'] == ' ' - assert request.json['SpacesBeforeByte'] == ' Text' - assert request.json['SpacesBeforeUnicode'] == ' Text' - assert request.json['SpacesAfterByte'] == 'Text ' - assert request.json['SpacesAfterUnicode'] == 'Text ' - assert request.json['SpacesBeforeAndAfterByte'] == ' Text ' - assert request.json['SpacesBeforeAndAfterUnicode'] == ' Text ' - assert request.json[u'啊齄丂狛'] == u'ꀕ' - assert request.json['RowKey'] == 'test2' - assert request.json[u'啊齄丂狛狜'] == 'hello' + assert request.json["EmptyByte"] == "" + assert request.json["EmptyUnicode"] == "" + assert request.json["SpacesOnlyByte"] == " " + assert request.json["SpacesOnlyUnicode"] == " " + assert request.json["SpacesBeforeByte"] == " Text" + assert request.json["SpacesBeforeUnicode"] == " Text" + assert request.json["SpacesAfterByte"] == "Text " + assert request.json["SpacesAfterUnicode"] == "Text " + assert request.json["SpacesBeforeAndAfterByte"] == " Text " + assert request.json["SpacesBeforeAndAfterUnicode"] == " Text " + assert request.json["啊齄丂狛"] == "ꀕ" + assert request.json["RowKey"] == "test2" + assert request.json["啊齄丂狛狜"] == "hello" assert request.json["singlequote"] == "a''''b" assert request.json["doublequote"] == 'a""""b' assert request.json["None"] == None return Response(status=200) -@basic_api.route("/headers", methods=['GET']) + +@basic_api.route("/headers", methods=["GET"]) def headers(): return Response( status=200, @@ -74,9 +71,10 @@ def headers(): "lowercase-header": "lowercase", "ALLCAPS-HEADER": "ALLCAPS", "CamelCase-Header": "camelCase", - } + }, ) + @basic_api.route("/anything", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "TRACE"]) def anything(): return jsonify( diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py index 1733cef3bf0f..d67eb2517149 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/encoding.py @@ -10,101 +10,92 @@ Blueprint, ) -encoding_api = Blueprint('encoding_api', __name__) +encoding_api = Blueprint("encoding_api", __name__) -@encoding_api.route('/latin-1', methods=['GET']) + +@encoding_api.route("/latin-1", methods=["GET"]) def latin_1(): - r = Response( - u"Latin 1: ÿ".encode("latin-1"), status=200 - ) + r = Response("Latin 1: ÿ".encode("latin-1"), status=200) r.headers["Content-Type"] = "text/plain; charset=latin-1" return r -@encoding_api.route('/latin-1-with-utf-8', methods=['GET']) + +@encoding_api.route("/latin-1-with-utf-8", methods=["GET"]) def latin_1_charset_utf8(): - r = Response( - u"Latin 1: ÿ".encode("latin-1"), status=200 - ) + r = Response("Latin 1: ÿ".encode("latin-1"), status=200) r.headers["Content-Type"] = "text/plain; charset=utf-8" return r -@encoding_api.route('/no-charset', methods=['GET']) + +@encoding_api.route("/no-charset", methods=["GET"]) def latin_1_no_charset(): - r = Response( - "Hello, world!", status=200 - ) + r = Response("Hello, world!", status=200) r.headers["Content-Type"] = "text/plain" return r -@encoding_api.route('/iso-8859-1', methods=['GET']) + +@encoding_api.route("/iso-8859-1", methods=["GET"]) def iso_8859_1(): - r = Response( - u"Accented: Österreich".encode("iso-8859-1"), status=200 # cspell:disable-line - ) + r = Response("Accented: Österreich".encode("iso-8859-1"), status=200) # cspell:disable-line r.headers["Content-Type"] = "text/plain" return r -@encoding_api.route('/emoji', methods=['GET']) + +@encoding_api.route("/emoji", methods=["GET"]) def emoji(): - r = Response( - u"👩", status=200 - ) + r = Response("👩", status=200) return r -@encoding_api.route('/emoji-family-skin-tone-modifier', methods=['GET']) + +@encoding_api.route("/emoji-family-skin-tone-modifier", methods=["GET"]) def emoji_family_skin_tone_modifier(): - r = Response( - u"👩🏻‍👩🏽‍👧🏾‍👦🏿 SSN: 859-98-0987", status=200 - ) + r = Response("👩🏻‍👩🏽‍👧🏾‍👦🏿 SSN: 859-98-0987", status=200) return r -@encoding_api.route('/korean', methods=['GET']) + +@encoding_api.route("/korean", methods=["GET"]) def korean(): - r = Response( - "아가", status=200 - ) + r = Response("아가", status=200) return r -@encoding_api.route('/json', methods=['GET']) + +@encoding_api.route("/json", methods=["GET"]) def json(): data = {"greeting": "hello", "recipient": "world"} content = dumps(data).encode("utf-16") - r = Response( - content, status=200 - ) + r = Response(content, status=200) r.headers["Content-Type"] = "application/json; charset=utf-16" return r -@encoding_api.route('/invalid-codec-name', methods=['GET']) + +@encoding_api.route("/invalid-codec-name", methods=["GET"]) def invalid_codec_name(): - r = Response( - "おはようございます。".encode("utf-8"), status=200 - ) + r = Response("おはようございます。".encode("utf-8"), status=200) r.headers["Content-Type"] = "text/plain; charset=invalid-codec-name" return r -@encoding_api.route('/no-charset', methods=['GET']) + +@encoding_api.route("/no-charset", methods=["GET"]) def no_charset(): - r = Response( - "Hello, world!", status=200 - ) + r = Response("Hello, world!", status=200) r.headers["Content-Type"] = "text/plain" return r -@encoding_api.route('/gzip', methods=['GET']) + +@encoding_api.route("/gzip", methods=["GET"]) def gzip_content_encoding(): r = Response( - b'\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\n\xcbH\xcd\xc9\xc9W(\xcf/\xcaI\x01\x00\x85\x11J\r\x0b\x00\x00\x00', status=200 + b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\n\xcbH\xcd\xc9\xc9W(\xcf/\xcaI\x01\x00\x85\x11J\r\x0b\x00\x00\x00", + status=200, ) r.headers["Content-Type"] = "text/plain" - r.headers['Content-Encoding'] = "gzip" + r.headers["Content-Encoding"] = "gzip" return r -@encoding_api.route('/deflate', methods=['GET']) + +@encoding_api.route("/deflate", methods=["GET"]) def deflate_content_encoding(): - r = Response( - b'\xcb\xc8T(\xc9H-J\x05\x00', status=200 - ) + r = Response(b"\xcb\xc8T(\xc9H-J\x05\x00", status=200) r.headers["Content-Type"] = "text/plain" - r.headers['Content-Encoding'] = "deflate" + r.headers["Content-Encoding"] = "deflate" return r diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py index 410157c98148..87b274e1f6b7 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/errors.py @@ -9,55 +9,55 @@ Blueprint, ) -errors_api = Blueprint('errors_api', __name__) +errors_api = Blueprint("errors_api", __name__) -@errors_api.route('/403', methods=['GET']) + +@errors_api.route("/403", methods=["GET"]) def get_403(): return Response(status=403) -@errors_api.route('/500', methods=['GET']) + +@errors_api.route("/500", methods=["GET"]) def get_500(): return Response(status=500) -@errors_api.route('/stream', methods=['GET']) + +@errors_api.route("/stream", methods=["GET"]) def get_stream(): class StreamingBody: def __iter__(self): yield b"Hello, " yield b"world!" + return Response(StreamingBody(), status=500) -@errors_api.route('/short-data', methods=['GET']) + +@errors_api.route("/short-data", methods=["GET"]) def get_short_data(): response = Response(b"X" * 4, status=200) response.automatically_set_content_length = False response.headers["Content-Length"] = "8" return response -@errors_api.route('/non-odatav4-body', methods=['GET']) + +@errors_api.route("/non-odatav4-body", methods=["GET"]) def get_non_odata_v4_response_body(): - return Response( - '{"code": 400, "error": {"global": ["MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API"]}}', - status=400 - ) + return Response('{"code": 400, "error": {"global": ["MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API"]}}', status=400) -@errors_api.route('/malformed-json', methods=['GET']) + +@errors_api.route("/malformed-json", methods=["GET"]) def get_malformed_json(): - return Response( - '{"code": 400, "error": {"global": ["MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API"]', - status=400 - ) + return Response('{"code": 400, "error": {"global": ["MY-ERROR-MESSAGE-THAT-IS-COMING-FROM-THE-API"]', status=400) + -@errors_api.route('/text', methods=['GET']) +@errors_api.route("/text", methods=["GET"]) def get_text_body(): - return Response( - 'I am throwing an error', - status=400 - ) + return Response("I am throwing an error", status=400) + -@errors_api.route('/odatav4', methods=['GET']) +@errors_api.route("/odatav4", methods=["GET"]) def get_odatav4(): return Response( '{"error": {"code": "501", "message": "Unsupported functionality", "target": "query", "details": [{"code": "301", "target": "$search", "message": "$search query option not supported"}], "innererror": {"trace": [], "context": {}}}}', - status=400 + status=400, ) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py index ffb5232587b0..92679dd0d2ca 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/headers.py @@ -5,15 +5,12 @@ # license information. # ------------------------------------------------------------------------- -from flask import ( - Response, - Blueprint, - request -) +from flask import Response, Blueprint, request -headers_api = Blueprint('headers_api', __name__) +headers_api = Blueprint("headers_api", __name__) -@headers_api.route("/case-insensitive", methods=['GET']) + +@headers_api.route("/case-insensitive", methods=["GET"]) def case_insensitive(): return Response( status=200, @@ -21,38 +18,33 @@ def case_insensitive(): "lowercase-header": "lowercase", "ALLCAPS-HEADER": "ALLCAPS", "CamelCase-Header": "camelCase", - } + }, ) -@headers_api.route("/empty", methods=['GET']) + +@headers_api.route("/empty", methods=["GET"]) def empty(): - return Response( - status=200, - headers={} - ) + return Response(status=200, headers={}) + -@headers_api.route("/duplicate/numbers", methods=['GET']) +@headers_api.route("/duplicate/numbers", methods=["GET"]) def duplicate_numbers(): - return Response( - status=200, - headers=[("a", "123"), ("a", "456"), ("b", "789")] - ) + return Response(status=200, headers=[("a", "123"), ("a", "456"), ("b", "789")]) + -@headers_api.route("/duplicate/case-insensitive", methods=['GET']) +@headers_api.route("/duplicate/case-insensitive", methods=["GET"]) def duplicate_case_insensitive(): return Response( - status=200, - headers=[("Duplicate-Header", "one"), ("Duplicate-Header", "two"), ("duplicate-header", "three")] + status=200, headers=[("Duplicate-Header", "one"), ("Duplicate-Header", "two"), ("duplicate-header", "three")] ) -@headers_api.route("/duplicate/commas", methods=['GET']) + +@headers_api.route("/duplicate/commas", methods=["GET"]) def duplicate_commas(): - return Response( - status=200, - headers=[("Set-Cookie", "a, b"), ("Set-Cookie", "c")] - ) + return Response(status=200, headers=[("Set-Cookie", "a, b"), ("Set-Cookie", "c")]) + -@headers_api.route("/ordered", methods=['GET']) +@headers_api.route("/ordered", methods=["GET"]) def ordered(): return Response( status=200, diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/helpers.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/helpers.py index 132fb6874b63..3421c5a97002 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/helpers.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/helpers.py @@ -12,57 +12,61 @@ import json ENV_HEADERS = ( - 'X-Varnish', - 'X-Request-Start', - 'X-Heroku-Queue-Depth', - 'X-Real-Ip', - 'X-Forwarded-Proto', - 'X-Forwarded-Protocol', - 'X-Forwarded-Ssl', - 'X-Heroku-Queue-Wait-Time', - 'X-Forwarded-For', - 'X-Heroku-Dynos-In-Use', - 'X-Forwarded-Protocol', - 'X-Forwarded-Port', - 'X-Request-Id', - 'Via', - 'Total-Route-Time', - 'Connect-Time' + "X-Varnish", + "X-Request-Start", + "X-Heroku-Queue-Depth", + "X-Real-Ip", + "X-Forwarded-Proto", + "X-Forwarded-Protocol", + "X-Forwarded-Ssl", + "X-Heroku-Queue-Wait-Time", + "X-Forwarded-For", + "X-Heroku-Dynos-In-Use", + "X-Forwarded-Protocol", + "X-Forwarded-Port", + "X-Request-Id", + "Via", + "Total-Route-Time", + "Connect-Time", ) + def assert_with_message(param_name, expected_value, actual_value): assert expected_value == actual_value, "Expected '{}' to be '{}', got '{}'".format( param_name, expected_value, actual_value ) + def jsonify(*args, **kwargs): response = flask_jsonify(*args, **kwargs) if not response.data.endswith(b"\n"): response.data += b"\n" return response + def get_url(request): """ Since we might be hosted behind a proxy, we need to check the X-Forwarded-Proto, X-Forwarded-Protocol, or X-Forwarded-SSL headers to find out what protocol was used to access us. """ - protocol = request.headers.get('X-Forwarded-Proto') or request.headers.get('X-Forwarded-Protocol') - if protocol is None and request.headers.get('X-Forwarded-Ssl') == 'on': - protocol = 'https' + protocol = request.headers.get("X-Forwarded-Proto") or request.headers.get("X-Forwarded-Protocol") + if protocol is None and request.headers.get("X-Forwarded-Ssl") == "on": + protocol = "https" if protocol is None: return request.url url = list(urlparse(request.url)) url[0] = protocol return urlunparse(url) + def get_files(): """Returns files dict from request context.""" files = dict() for k, v in request.files.items(): - content_type = request.files[k].content_type or 'application/octet-stream' + content_type = request.files[k].content_type or "application/octet-stream" val = json_safe(v.read(), content_type) if files.get(k): if not isinstance(files[k], list): @@ -73,12 +77,13 @@ def get_files(): return files + def get_headers(hide_env=True): """Returns headers dict from request context.""" headers = dict(request.headers.items()) - if hide_env and ('show_env' not in request.args): + if hide_env and ("show_env" not in request.args): for key in ENV_HEADERS: try: del headers[key] @@ -87,6 +92,7 @@ def get_headers(hide_env=True): return CaseInsensitiveDict(headers.items()) + def semiflatten(multi): """Convert a MultiDict into a regular dict. If there are more than one value for a key, the result will have a list of values for the key. Otherwise it @@ -100,7 +106,8 @@ def semiflatten(multi): else: return multi -def json_safe(string, content_type='application/octet-stream'): + +def json_safe(string, content_type="application/octet-stream"): """Returns JSON-safe version of `string`. If `string` is a Unicode string or a valid UTF-8, it is returned unmodified, @@ -112,28 +119,24 @@ def json_safe(string, content_type='application/octet-stream'): URL scheme was chosen for its simplicity. """ try: - string = string.decode('utf-8') + string = string.decode("utf-8") json.dumps(string) return string except (ValueError, TypeError): - return b''.join([ - b'data:', - content_type.encode('utf-8'), - b';base64,', - base64.b64encode(string) - ]).decode('utf-8') + return b"".join([b"data:", content_type.encode("utf-8"), b";base64,", base64.b64encode(string)]).decode("utf-8") + def get_dict(*keys, **extras): """Returns request dict of given keys.""" - _keys = ('url', 'args', 'form', 'data', 'origin', 'headers', 'files', 'json', 'method') + _keys = ("url", "args", "form", "data", "origin", "headers", "files", "json", "method") assert all(map(_keys.__contains__, keys)) data = request.data form = semiflatten(request.form) try: - _json = json.loads(data.decode('utf-8')) + _json = json.loads(data.decode("utf-8")) except (ValueError, TypeError): _json = None @@ -142,7 +145,7 @@ def get_dict(*keys, **extras): args=semiflatten(request.args), form=form, data=json_safe(data), - origin=request.headers.get('X-Forwarded-For', request.remote_addr), + origin=request.headers.get("X-Forwarded-For", request.remote_addr), headers=get_headers(), files=get_files(), json=_json, @@ -158,9 +161,9 @@ def get_dict(*keys, **extras): return out_d + def get_base_url(request): return "http://" + request.host -__all__ = ["assert_with_message", - "get_dict", - "jsonify"] \ No newline at end of file + +__all__ = ["assert_with_message", "get_dict", "jsonify"] diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py index 9be44121d2d4..19bcf537db46 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/multipart.py @@ -12,7 +12,7 @@ ) from .helpers import assert_with_message -multipart_api = Blueprint('multipart_api', __name__) +multipart_api = Blueprint("multipart_api", __name__) multipart_header_start = "multipart/form-data; boundary=" @@ -20,23 +20,26 @@ # in requests, we see the file content through request.form # in aiohttp, we see the file through request.files -@multipart_api.route('/basic', methods=['POST']) + +@multipart_api.route("/basic", methods=["POST"]) def basic(): - assert_with_message("content type", multipart_header_start, request.content_type[:len(multipart_header_start)]) + assert_with_message("content type", multipart_header_start, request.content_type[: len(multipart_header_start)]) if request.files: # aiohttp assert_with_message("content length", 258, request.content_length) assert_with_message("num files", 1, len(request.files)) - assert_with_message("has file named fileContent", True, bool(request.files.get('fileContent'))) - file_content = request.files['fileContent'] + assert_with_message("has file named fileContent", True, bool(request.files.get("fileContent"))) + file_content = request.files["fileContent"] assert_with_message("file content type", "application/octet-stream", file_content.content_type) assert_with_message("file content length", 14, file_content.content_length) assert_with_message("filename", "fileContent", file_content.filename) - assert_with_message("has content disposition header", True, bool(file_content.headers.get("Content-Disposition"))) + assert_with_message( + "has content disposition header", True, bool(file_content.headers.get("Content-Disposition")) + ) assert_with_message( "content disposition", 'form-data; name="fileContent"; filename="fileContent"; filename*=utf-8\'\'fileContent', - file_content.headers["Content-Disposition"] + file_content.headers["Content-Disposition"], ) elif request.form: # requests @@ -46,31 +49,34 @@ def basic(): return Response(status=400) # should be either of these return Response(status=200) -@multipart_api.route('/data-and-files', methods=['POST']) + +@multipart_api.route("/data-and-files", methods=["POST"]) def data_and_files(): - assert_with_message("content type", multipart_header_start, request.content_type[:len(multipart_header_start)]) + assert_with_message("content type", multipart_header_start, request.content_type[: len(multipart_header_start)]) assert_with_message("message", "Hello, world!", request.form["message"]) assert_with_message("message", "", request.form["fileContent"]) return Response(status=200) -@multipart_api.route('/data-and-files-tuple', methods=['POST']) + +@multipart_api.route("/data-and-files-tuple", methods=["POST"]) def data_and_files_tuple(): - assert_with_message("content type", multipart_header_start, request.content_type[:len(multipart_header_start)]) + assert_with_message("content type", multipart_header_start, request.content_type[: len(multipart_header_start)]) assert_with_message("message", ["abc"], request.form["message"]) assert_with_message("message", [""], request.form["fileContent"]) return Response(status=200) -@multipart_api.route('/non-seekable-filelike', methods=['POST']) + +@multipart_api.route("/non-seekable-filelike", methods=["POST"]) def non_seekable_filelike(): - assert_with_message("content type", multipart_header_start, request.content_type[:len(multipart_header_start)]) + assert_with_message("content type", multipart_header_start, request.content_type[: len(multipart_header_start)]) if request.files: # aiohttp len_files = len(request.files) assert_with_message("num files", 1, len_files) # assert_with_message("content length", 258, request.content_length) assert_with_message("num files", 1, len(request.files)) - assert_with_message("has file named file", True, bool(request.files.get('file'))) - file = request.files['file'] + assert_with_message("has file named file", True, bool(request.files.get("file"))) + file = request.files["file"] assert_with_message("file content type", "application/octet-stream", file.content_type) assert_with_message("file content length", 14, file.content_length) assert_with_message("filename", "file", file.filename) @@ -78,7 +84,7 @@ def non_seekable_filelike(): assert_with_message( "content disposition", 'form-data; name="fileContent"; filename="fileContent"; filename*=utf-8\'\'fileContent', - file.headers["Content-Disposition"] + file.headers["Content-Disposition"], ) elif request.form: # requests @@ -87,7 +93,8 @@ def non_seekable_filelike(): return Response(status=400) return Response(status=200) -@multipart_api.route('/request', methods=["POST"]) + +@multipart_api.route("/request", methods=["POST"]) def multipart_request(): body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" @@ -115,5 +122,7 @@ def multipart_request(): "Time:2018-06-14T16:46:54.6040685Z\r\n" "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" ) - return Response(body_as_str.encode('ascii'), content_type="multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed") - + return Response( + body_as_str.encode("ascii"), + content_type="multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed", + ) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py index 326f2dce63dc..005867a38e35 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py @@ -10,163 +10,162 @@ ) from .helpers import get_base_url, assert_with_message -polling_api = Blueprint('polling_api', __name__) +polling_api = Blueprint("polling_api", __name__) -@polling_api.route('/post/location-and-operation-location', methods=['POST']) +@polling_api.route("/post/location-and-operation-location", methods=["POST"]) def post_with_location_and_operation_location_initial(): base_url = get_base_url(request) return Response( '{"properties":{"provisioningState": "InProgress"}}', headers={ - 'location': '{}/polling/location-url'.format(base_url), - 'operation-location': '{}/polling/operation-location-url'.format(base_url), + "location": "{}/polling/location-url".format(base_url), + "operation-location": "{}/polling/operation-location-url".format(base_url), }, - status=202 + status=202, ) -@polling_api.route('/location-url', methods=['GET']) + +@polling_api.route("/location-url", methods=["GET"]) def location_url(): - return Response( - '{"location_result": true}', - status=200 - ) + return Response('{"location_result": true}', status=200) + -@polling_api.route('/location-no-body-url', methods=['GET']) +@polling_api.route("/location-no-body-url", methods=["GET"]) def location_no_body_url(): - return Response( - status=200 - ) + return Response(status=200) + -@polling_api.route('/operation-location-url', methods=['GET']) +@polling_api.route("/operation-location-url", methods=["GET"]) def operation_location_url(): - return Response( - '{"status": "Succeeded"}', - status=200 - ) + return Response('{"status": "Succeeded"}', status=200) -@polling_api.route('/post/location-and-operation-location-no-body', methods=['POST']) + +@polling_api.route("/post/location-and-operation-location-no-body", methods=["POST"]) def post_with_location_and_operation_location_initial_no_body(): base_url = get_base_url(request) return Response( '{"properties":{"provisioningState": "InProgress"}}', headers={ - 'location': '{}/polling/location-no-body-url'.format(base_url), - 'operation-location': '{}/polling/operation-location-url'.format(base_url), + "location": "{}/polling/location-no-body-url".format(base_url), + "operation-location": "{}/polling/operation-location-url".format(base_url), }, - status=202 + status=202, ) -@polling_api.route('/post/resource-location', methods=['POST']) + +@polling_api.route("/post/resource-location", methods=["POST"]) def resource_location(): base_url = get_base_url(request) return Response( - '', + "", status=202, headers={ - 'operation-location': '{}/polling/post/resource-location/operation-location-url'.format(base_url), - } + "operation-location": "{}/polling/post/resource-location/operation-location-url".format(base_url), + }, ) -@polling_api.route('/post/resource-location/operation-location-url', methods=['GET']) + +@polling_api.route("/post/resource-location/operation-location-url", methods=["GET"]) def resource_location_operation_location(): base_url = get_base_url(request) - resource_location = '{}/polling/location-url'.format(base_url) + resource_location = "{}/polling/location-url".format(base_url) return Response( '{"status": "Succeeded", "resourceLocation": "' + resource_location + '"}', status=200, ) -@polling_api.route('/no-polling', methods=['PUT']) + +@polling_api.route("/no-polling", methods=["PUT"]) def no_polling(): - return Response( - '{"properties":{"provisioningState": "Succeeded"}}', - status=201 - ) + return Response('{"properties":{"provisioningState": "Succeeded"}}', status=201) + -@polling_api.route('/operation-location', methods=["DELETE", "POST", "PUT", "PATCH", "GET"]) +@polling_api.route("/operation-location", methods=["DELETE", "POST", "PUT", "PATCH", "GET"]) def operation_location(): base_url = get_base_url(request) return Response( status=201, headers={ - 'operation-location': '{}/polling/operation-location-url'.format(base_url), - } + "operation-location": "{}/polling/operation-location-url".format(base_url), + }, ) -@polling_api.route('/bad-operation-location', methods=["PUT", "PATCH", "DELETE", "POST"]) + +@polling_api.route("/bad-operation-location", methods=["PUT", "PATCH", "DELETE", "POST"]) def bad_operation_location(): return Response( status=201, headers={ - 'operation-location': 'http://localhost:5000/does-not-exist', - } + "operation-location": "http://localhost:5000/does-not-exist", + }, ) -@polling_api.route('/location', methods=["PUT", "PATCH", "DELETE", "POST"]) + +@polling_api.route("/location", methods=["PUT", "PATCH", "DELETE", "POST"]) def location(): base_url = get_base_url(request) return Response( status=201, headers={ - 'location': '{}/polling/location-url'.format(base_url), - } + "location": "{}/polling/location-url".format(base_url), + }, ) -@polling_api.route('/bad-location', methods=["PUT", "PATCH", "DELETE", "POST"]) + +@polling_api.route("/bad-location", methods=["PUT", "PATCH", "DELETE", "POST"]) def bad_location(): return Response( status=201, headers={ - 'location': 'http://localhost:5000/does-not-exist', - } + "location": "http://localhost:5000/does-not-exist", + }, ) -@polling_api.route('/initial-body-invalid', methods=["PUT"]) + +@polling_api.route("/initial-body-invalid", methods=["PUT"]) def initial_body_invalid(): base_url = get_base_url(request) return Response( "", status=201, headers={ - 'location': '{}/polling/location-url'.format(base_url), - } + "location": "{}/polling/location-url".format(base_url), + }, ) -@polling_api.route('/request-id', methods=["POST"]) + +@polling_api.route("/request-id", methods=["POST"]) def request_id(): base_url = get_base_url(request) return Response( "", status=201, headers={ - 'location': '{}/polling/request-id-location'.format(base_url), - } + "location": "{}/polling/request-id-location".format(base_url), + }, ) -@polling_api.route('/request-id-location', methods=["GET"]) + +@polling_api.route("/request-id-location", methods=["GET"]) def request_id_location(): - assert_with_message("request id", request.headers['X-Ms-Client-Request-Id'], "123456789") - return Response( - '{"status": "Succeeded"}', - status=200 - ) + assert_with_message("request id", request.headers["X-Ms-Client-Request-Id"], "123456789") + return Response('{"status": "Succeeded"}', status=200) + -@polling_api.route('/polling-with-options', methods=["PUT"]) +@polling_api.route("/polling-with-options", methods=["PUT"]) def polling_with_options_first(): base_url = get_base_url(request) return Response( '{"properties":{"provisioningState": "InProgress"}}', headers={ - 'location': '{}/polling/final-get-with-location'.format(base_url), - 'operation-location': '{}/polling/post/resource-location/operation-location-url'.format(base_url), + "location": "{}/polling/final-get-with-location".format(base_url), + "operation-location": "{}/polling/post/resource-location/operation-location-url".format(base_url), }, - status=202 + status=202, ) -@polling_api.route('/final-get-with-location', methods=["GET"]) + +@polling_api.route("/final-get-with-location", methods=["GET"]) def polling_with_options_final_get_with_location(): - return Response( - '{"returnedFrom": "locationHeaderUrl"}', - status=200 - ) + return Response('{"returnedFrom": "locationHeaderUrl"}', status=200) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/streams.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/streams.py index aeb20cae9e51..a8821b447956 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/streams.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/streams.py @@ -12,70 +12,82 @@ Blueprint, ) -streams_api = Blueprint('streams_api', __name__) +streams_api = Blueprint("streams_api", __name__) + class StreamingBody: def __iter__(self): yield b"Hello, " yield b"world!" + def streaming_body(): yield b"Hello, " yield b"world!" + def stream_json_error(): yield '{"error": {"code": "BadRequest", ' - yield' "message": "You made a bad request"}}' + yield ' "message": "You made a bad request"}}' + def streaming_test(): yield b"test" + def stream_compressed_header_error(): - yield b'test' + yield b"test" + def stream_compressed_no_header(): - with gzip.open('test.tar.gz', 'wb') as f: + with gzip.open("test.tar.gz", "wb") as f: f.write(b"test") - - with open(os.path.join(os.path.abspath('test.tar.gz')), "rb") as fd: + + with open(os.path.join(os.path.abspath("test.tar.gz")), "rb") as fd: yield fd.read() - + os.remove("test.tar.gz") - -@streams_api.route('/basic', methods=['GET']) + + +@streams_api.route("/basic", methods=["GET"]) def basic(): return Response(streaming_body(), status=200) -@streams_api.route('/iterable', methods=['GET']) + +@streams_api.route("/iterable", methods=["GET"]) def iterable(): return Response(StreamingBody(), status=200) -@streams_api.route('/error', methods=['GET']) + +@streams_api.route("/error", methods=["GET"]) def error(): return Response(stream_json_error(), status=400) -@streams_api.route('/string', methods=['GET']) + +@streams_api.route("/string", methods=["GET"]) def string(): - return Response( - streaming_test(), status=200, mimetype="text/plain" - ) + return Response(streaming_test(), status=200, mimetype="text/plain") + -@streams_api.route('/compressed_no_header', methods=['GET']) +@streams_api.route("/compressed_no_header", methods=["GET"]) def compressed_no_header(): return Response(stream_compressed_no_header(), status=300) -@streams_api.route('/compressed', methods=['GET']) + +@streams_api.route("/compressed", methods=["GET"]) def compressed(): return Response(stream_compressed_header_error(), status=300, headers={"Content-Encoding": "gzip"}) + def compressed_stream(): - with tempfile.TemporaryFile(mode='w+b') as f: - gzf = gzip.GzipFile(mode='w+b', fileobj=f) + with tempfile.TemporaryFile(mode="w+b") as f: + gzf = gzip.GzipFile(mode="w+b", fileobj=f) gzf.write(b"test") gzf.flush() f.seek(0) yield f.read() - -@streams_api.route('/decompress_header', methods=['GET']) + + +@streams_api.route("/decompress_header", methods=["GET"]) def decompress_header(): return Response(compressed_stream(), status=200, headers={"Content-Encoding": "gzip"}) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/structures.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/structures.py index 1e443986cdf2..69ac916cc44f 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/structures.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/structures.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- + class CaseInsensitiveDict(dict): """Case-insensitive Dictionary for headers. diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/urlencoded.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/urlencoded.py index 4ea2bdd2795d..c9fb0167ce9c 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/urlencoded.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/urlencoded.py @@ -11,9 +11,10 @@ ) from .helpers import assert_with_message -urlencoded_api = Blueprint('urlencoded_api', __name__) +urlencoded_api = Blueprint("urlencoded_api", __name__) -@urlencoded_api.route('/pet/add/', methods=['POST']) + +@urlencoded_api.route("/pet/add/", methods=["POST"]) def basic(pet_id): assert_with_message("pet_id", "1", pet_id) assert_with_message("content type", "application/x-www-form-urlencoded", request.content_type) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/xml_route.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/xml_route.py index c19aed97b6b5..a0ba8c25637c 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/xml_route.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/xml_route.py @@ -12,9 +12,10 @@ ) from .helpers import assert_with_message -xml_api = Blueprint('xml_api', __name__) +xml_api = Blueprint("xml_api", __name__) -@xml_api.route('/basic', methods=['GET', 'PUT']) + +@xml_api.route("/basic", methods=["GET", "PUT"]) def basic(): basic_body = """ """ - if request.method == 'GET': + if request.method == "GET": return Response(basic_body, status=200) - elif request.method == 'PUT': + elif request.method == "PUT": assert_with_message("content length", str(len(request.data)), request.headers["Content-Length"]) parsed_xml = ET.fromstring(request.data.decode("utf-8")) assert_with_message("tag", "slideshow", parsed_xml.tag) attributes = parsed_xml.attrib - assert_with_message("title attribute", "Sample Slide Show", attributes['title']) - assert_with_message("date attribute", "Date of publication", attributes['date']) - assert_with_message("author attribute", "Yours Truly", attributes['author']) + assert_with_message("title attribute", "Sample Slide Show", attributes["title"]) + assert_with_message("date attribute", "Date of publication", attributes["date"]) + assert_with_message("author attribute", "Yours Truly", attributes["author"]) return Response(status=200) return Response("You have passed in method '{}' that is not 'GET' or 'PUT'".format(request.method), status=400) diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/setup.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/setup.py index 63f51acde3f2..26366f2aa08f 100644 --- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/setup.py +++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/setup.py @@ -13,25 +13,25 @@ name="coretestserver", version=version, include_package_data=True, - description='Testserver for Python Core', - long_description='Testserver for Python Core', - license='MIT License', - author='Microsoft Corporation', - author_email='azpysdkhelp@microsoft.com', - url='https://github.com/iscai-msft/core.testserver', + description="Testserver for Python Core", + long_description="Testserver for Python Core", + license="MIT License", + author="Microsoft Corporation", + author_email="azpysdkhelp@microsoft.com", + url="https://github.com/iscai-msft/core.testserver", classifiers=[ - 'Development Status :: 4 - Beta', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: MIT License', + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: MIT License", ], packages=find_packages(), install_requires=[ "flask==1.1.4", - ] + ], ) diff --git a/sdk/core/azure-core/tests/tracing_common.py b/sdk/core/azure-core/tests/tracing_common.py index 2a2c55fe987f..e7720aba8f6f 100644 --- a/sdk/core/azure-core/tests/tracing_common.py +++ b/sdk/core/azure-core/tests/tracing_common.py @@ -6,6 +6,7 @@ from contextlib import contextmanager from azure.core.tracing import HttpSpanMixin, SpanKind from typing import Union, Sequence, Optional, Dict + AttributeValue = Union[ str, bool, @@ -18,6 +19,7 @@ ] Attributes = Optional[Dict[str, AttributeValue]] + class FakeSpan(HttpSpanMixin, object): # Keep a fake context of the current one CONTEXT = [] @@ -74,7 +76,6 @@ def kind(self): """Get the span kind of this span.""" return self._kind - @kind.setter def kind(self, value): # type: (SpanKind) -> None @@ -107,7 +108,7 @@ def to_header(self): Returns a dictionary with the header labels and values. :return: A key value pair dictionary """ - return {'traceparent': '123456789'} + return {"traceparent": "123456789"} def add_attribute(self, key, value): # type: (str, Union[str, int]) -> None @@ -134,7 +135,7 @@ def get_trace_parent(self): :return: a traceparent string :rtype: str """ - return self.to_header()['traceparent'] + return self.to_header()["traceparent"] @classmethod def link(cls, traceparent, attributes=None): @@ -145,9 +146,7 @@ def link(cls, traceparent, attributes=None): :param traceparent: A complete traceparent :type traceparent: str """ - cls.link_from_headers({ - 'traceparent': traceparent - }) + cls.link_from_headers({"traceparent": traceparent}) @classmethod def link_from_headers(cls, headers, attributes=None): @@ -180,8 +179,7 @@ def get_current_tracer(cls): @contextmanager def change_context(cls, span): # type: (Span) -> ContextManager - """Change the context for the life of this context manager. - """ + """Change the context for the life of this context manager.""" try: cls.CONTEXT.append(span) yield @@ -191,8 +189,7 @@ def change_context(cls, span): @classmethod def set_current_span(cls, span): # type: (Span) -> None - """Not supported by OpenTelemetry. - """ + """Not supported by OpenTelemetry.""" raise NotImplementedError() @classmethod diff --git a/sdk/core/azure-core/tests/utils.py b/sdk/core/azure-core/tests/utils.py index a70c5db98243..e1807a85e6cd 100644 --- a/sdk/core/azure-core/tests/utils.py +++ b/sdk/core/azure-core/tests/utils.py @@ -5,15 +5,18 @@ # ------------------------------------------------------------------------- import pytest import types + ############################## LISTS USED TO PARAMETERIZE TESTS ############################## from azure.core.rest import HttpRequest as RestHttpRequest from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest from azure.core.pipeline._tools import is_rest + HTTP_REQUESTS = [PipelineTransportHttpRequest, RestHttpRequest] REQUESTS_TRANSPORT_RESPONSES = [] from azure.core.pipeline.transport import HttpResponse as PipelineTransportHttpResponse from azure.core.rest._http_response_impl import HttpResponseImpl as RestHttpResponse + HTTP_RESPONSES = [PipelineTransportHttpResponse, RestHttpResponse] ASYNC_HTTP_RESPONSES = [] @@ -21,6 +24,7 @@ try: from azure.core.pipeline.transport import AsyncHttpResponse as PipelineTransportAsyncHttpResponse from azure.core.rest._http_response_impl_async import AsyncHttpResponseImpl as RestAsyncHttpResponse + ASYNC_HTTP_RESPONSES = [PipelineTransportAsyncHttpResponse, RestAsyncHttpResponse] except (ImportError, SyntaxError): pass @@ -28,20 +32,30 @@ try: from azure.core.pipeline.transport import RequestsTransportResponse as PipelineTransportRequestsTransportResponse from azure.core.rest._requests_basic import RestRequestsTransportResponse + REQUESTS_TRANSPORT_RESPONSES = [PipelineTransportRequestsTransportResponse, RestRequestsTransportResponse] except ImportError: pass -from azure.core.pipeline.transport._base import HttpClientTransportResponse as PipelineTransportHttpClientTransportResponse +from azure.core.pipeline.transport._base import ( + HttpClientTransportResponse as PipelineTransportHttpClientTransportResponse, +) from azure.core.rest._http_response_impl import RestHttpClientTransportResponse + HTTP_CLIENT_TRANSPORT_RESPONSES = [PipelineTransportHttpClientTransportResponse, RestHttpClientTransportResponse] ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [] try: - from azure.core.pipeline.transport import AsyncioRequestsTransportResponse as PipelineTransportAsyncioRequestsTransportResponse + from azure.core.pipeline.transport import ( + AsyncioRequestsTransportResponse as PipelineTransportAsyncioRequestsTransportResponse, + ) from azure.core.rest._requests_asyncio import RestAsyncioRequestsTransportResponse - ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [PipelineTransportAsyncioRequestsTransportResponse, RestAsyncioRequestsTransportResponse] + + ASYNCIO_REQUESTS_TRANSPORT_RESPONSES = [ + PipelineTransportAsyncioRequestsTransportResponse, + RestAsyncioRequestsTransportResponse, + ] except (ImportError, SyntaxError): pass @@ -50,17 +64,20 @@ try: from azure.core.pipeline.transport import AioHttpTransportResponse as PipelineTransportAioHttpTransportResponse from azure.core.rest._aiohttp import RestAioHttpTransportResponse + AIOHTTP_TRANSPORT_RESPONSES = [PipelineTransportAioHttpTransportResponse, RestAioHttpTransportResponse] except (ImportError, SyntaxError): pass ############################## HELPER FUNCTIONS ############################## + def request_and_responses_product(*args): pipeline_transport = tuple([PipelineTransportHttpRequest]) + tuple(arg[0] for arg in args) rest = tuple([RestHttpRequest]) + tuple(arg[1] for arg in args) return [pipeline_transport, rest] + def create_http_request(http_request, *args, **kwargs): if hasattr(http_request, "content"): method = args[0] @@ -77,29 +94,19 @@ def create_http_request(http_request, *args, **kwargs): data = args[4] except IndexError: data = None - return http_request( - method=method, - url=url, - headers=headers, - files=files, - data=data, - **kwargs - ) + return http_request(method=method, url=url, headers=headers, files=files, data=data, **kwargs) return http_request(*args, **kwargs) + def create_transport_response(http_response, *args, **kwargs): # this creates transport-specific responses, # like requests responses / aiohttp responses if is_rest(http_response): block_size = args[2] if len(args) > 2 else None - return http_response( - request=args[0], - internal_response=args[1], - block_size=block_size, - **kwargs - ) + return http_response(request=args[0], internal_response=args[1], block_size=block_size, **kwargs) return http_response(*args, **kwargs) + def create_http_response(http_response, *args, **kwargs): # since the actual http_response object is # an ABC for our new responses, it's a little more @@ -121,13 +128,14 @@ def create_http_response(http_response, *args, **kwargs): ) return http_response(*args, **kwargs) + def readonly_checks(response, old_response_class): # though we want these properties to be completely readonly, it doesn't work # for the backcompat properties assert isinstance(response.request, RestHttpRequest) assert isinstance(response.status_code, int) assert response.headers - assert response.content_type == 'text/html; charset=utf-8' + assert response.content_type == "text/html; charset=utf-8" assert response.is_closed with pytest.raises(AttributeError): @@ -152,7 +160,7 @@ def readonly_checks(response, old_response_class): old_response = old_response_class(response.request, response.internal_response, response.block_size) for attr in dir(response): - if attr[0] == '_': + if attr[0] == "_": # don't care about private variables continue if type(getattr(response, attr)) == types.MethodType: diff --git a/sdk/core/azure-mgmt-core/azure/__init__.py b/sdk/core/azure-mgmt-core/azure/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-mgmt-core/azure/__init__.py +++ b/sdk/core/azure-mgmt-core/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/__init__.py b/sdk/core/azure-mgmt-core/azure/mgmt/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/__init__.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py index 289e82ea70a0..bc7bb190dabb 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/_async_pipeline_client.py @@ -46,18 +46,15 @@ class AsyncARMPipelineClient(AsyncPipelineClient): def __init__(self, base_url, **kwargs): if "policies" not in kwargs: if "config" not in kwargs: - raise ValueError( - "Current implementation requires to pass 'config' if you don't pass 'policies'" - ) - per_call_policies = kwargs.get('per_call_policies', []) + raise ValueError("Current implementation requires to pass 'config' if you don't pass 'policies'") + per_call_policies = kwargs.get("per_call_policies", []) if isinstance(per_call_policies, Iterable): per_call_policies.append(AsyncARMAutoResourceProviderRegistrationPolicy()) else: - per_call_policies = [per_call_policies, - AsyncARMAutoResourceProviderRegistrationPolicy()] + per_call_policies = [per_call_policies, AsyncARMAutoResourceProviderRegistrationPolicy()] kwargs["per_call_policies"] = per_call_policies - config = kwargs.get('config') + config = kwargs.get("config") if not config.http_logging_policy: - config.http_logging_policy = kwargs.get('http_logging_policy', ARMHttpLoggingPolicy(**kwargs)) + config.http_logging_policy = kwargs.get("http_logging_policy", ARMHttpLoggingPolicy(**kwargs)) kwargs["config"] = config super(AsyncARMPipelineClient, self).__init__(base_url, **kwargs) diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py index 4e80070758aa..7cefbf347562 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/_pipeline_client.py @@ -44,18 +44,15 @@ class ARMPipelineClient(PipelineClient): def __init__(self, base_url, **kwargs): if "policies" not in kwargs: if "config" not in kwargs: - raise ValueError( - "Current implementation requires to pass 'config' if you don't pass 'policies'" - ) - per_call_policies = kwargs.get('per_call_policies', []) + raise ValueError("Current implementation requires to pass 'config' if you don't pass 'policies'") + per_call_policies = kwargs.get("per_call_policies", []) if isinstance(per_call_policies, Iterable): per_call_policies.append(ARMAutoResourceProviderRegistrationPolicy()) else: - per_call_policies = [per_call_policies, - ARMAutoResourceProviderRegistrationPolicy()] + per_call_policies = [per_call_policies, ARMAutoResourceProviderRegistrationPolicy()] kwargs["per_call_policies"] = per_call_policies - config = kwargs.get('config') + config = kwargs.get("config") if not config.http_logging_policy: - config.http_logging_policy = kwargs.get('http_logging_policy', ARMHttpLoggingPolicy(**kwargs)) + config.http_logging_policy = kwargs.get("http_logging_policy", ARMHttpLoggingPolicy(**kwargs)) kwargs["config"] = config super(ARMPipelineClient, self).__init__(base_url, **kwargs) diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/__init__.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/__init__.py index 7ec2459658ef..1d786eebc1cd 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/__init__.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/__init__.py @@ -30,24 +30,26 @@ from ._authentication_async import AsyncARMChallengeAuthenticationPolicy from ._base_async import AsyncARMAutoResourceProviderRegistrationPolicy + class ARMHttpLoggingPolicy(HttpLoggingPolicy): - """HttpLoggingPolicy with ARM specific safe headers fopr loggers. - """ + """HttpLoggingPolicy with ARM specific safe headers fopr loggers.""" - DEFAULT_HEADERS_ALLOWLIST = HttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST | set([ - # https://docs.microsoft.com/azure/azure-resource-manager/management/request-limits-and-throttling#remaining-requests - "x-ms-ratelimit-remaining-subscription-reads", - "x-ms-ratelimit-remaining-subscription-writes", - "x-ms-ratelimit-remaining-tenant-reads", - "x-ms-ratelimit-remaining-tenant-writes", - "x-ms-ratelimit-remaining-subscription-resource-requests", - "x-ms-ratelimit-remaining-subscription-resource-entities-read", - "x-ms-ratelimit-remaining-tenant-resource-requests", - "x-ms-ratelimit-remaining-tenant-resource-entities-read", - # https://docs.microsoft.com/azure/virtual-machines/troubleshooting/troubleshooting-throttling-errors#call-rate-informational-response-headers - "x-ms-ratelimit-remaining-resource", - "x-ms-request-charge", - ]) + DEFAULT_HEADERS_ALLOWLIST = HttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST | set( + [ + # https://docs.microsoft.com/azure/azure-resource-manager/management/request-limits-and-throttling#remaining-requests + "x-ms-ratelimit-remaining-subscription-reads", + "x-ms-ratelimit-remaining-subscription-writes", + "x-ms-ratelimit-remaining-tenant-reads", + "x-ms-ratelimit-remaining-tenant-writes", + "x-ms-ratelimit-remaining-subscription-resource-requests", + "x-ms-ratelimit-remaining-subscription-resource-entities-read", + "x-ms-ratelimit-remaining-tenant-resource-requests", + "x-ms-ratelimit-remaining-tenant-resource-entities-read", + # https://docs.microsoft.com/azure/virtual-machines/troubleshooting/troubleshooting-throttling-errors#call-rate-informational-response-headers + "x-ms-ratelimit-remaining-resource", + "x-ms-request-charge", + ] + ) __all__ = [ @@ -55,5 +57,5 @@ class ARMHttpLoggingPolicy(HttpLoggingPolicy): "ARMChallengeAuthenticationPolicy", "ARMHttpLoggingPolicy", "AsyncARMAutoResourceProviderRegistrationPolicy", - "AsyncARMChallengeAuthenticationPolicy" + "AsyncARMChallengeAuthenticationPolicy", ] diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base.py index dd19692432b7..8f00b976a566 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base.py @@ -38,8 +38,7 @@ class ARMAutoResourceProviderRegistrationPolicy(HTTPPolicy): - """Auto register an ARM resource provider if not done yet. - """ + """Auto register an ARM resource provider if not done yet.""" def send(self, request): # type: (PipelineRequest[HTTPRequestType], Any) -> PipelineResponse[HTTPRequestType, HTTPResponseType] @@ -82,9 +81,7 @@ def _extract_subscription_url(url): @staticmethod def _build_next_request(initial_request, method, url): request = HttpRequest(method, url) - context = PipelineContext( - initial_request.context.transport, **initial_request.context.options - ) + context = PipelineContext(initial_request.context.transport, **initial_request.context.options) return PipelineRequest(request, context) def _register_rp(self, initial_request, url_prefix, rp_name): @@ -92,27 +89,20 @@ def _register_rp(self, initial_request, url_prefix, rp_name): Return False if we have a reason to believe this didn't work """ - post_url = "{}providers/{}/register?api-version=2016-02-01".format( - url_prefix, rp_name - ) + post_url = "{}providers/{}/register?api-version=2016-02-01".format(url_prefix, rp_name) get_url = "{}providers/{}?api-version=2016-02-01".format(url_prefix, rp_name) _LOGGER.warning( - "Resource provider '%s' used by this operation is not " - "registered. We are registering for you.", + "Resource provider '%s' used by this operation is not " "registered. We are registering for you.", rp_name, ) - post_response = self.next.send( - self._build_next_request(initial_request, "POST", post_url) - ) + post_response = self.next.send(self._build_next_request(initial_request, "POST", post_url)) if post_response.http_response.status_code != 200: _LOGGER.warning("Registration failed. Please register manually.") return False while True: time.sleep(10) - get_response = self.next.send( - self._build_next_request(initial_request, "GET", get_url) - ) + get_response = self.next.send(self._build_next_request(initial_request, "GET", get_url)) rp_info = json.loads(get_response.http_response.text()) if rp_info["registrationState"] == "Registered": _LOGGER.warning("Registration succeeded.") diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base_async.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base_async.py index ce6c58b27f0f..6208df12e5ff 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base_async.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/policies/_base_async.py @@ -36,11 +36,8 @@ _LOGGER = logging.getLogger(__name__) -class AsyncARMAutoResourceProviderRegistrationPolicy( - ARMAutoResourceProviderRegistrationPolicy, AsyncHTTPPolicy -): - """Auto register an ARM resource provider if not done yet. - """ +class AsyncARMAutoResourceProviderRegistrationPolicy(ARMAutoResourceProviderRegistrationPolicy, AsyncHTTPPolicy): + """Auto register an ARM resource provider if not done yet.""" async def send(self, request: PipelineRequest): # pylint: disable=invalid-overridden-method http_request = request.http_request @@ -49,9 +46,7 @@ async def send(self, request: PipelineRequest): # pylint: disable=invalid-overr rp_name = self._check_rp_not_registered_err(response) if rp_name: url_prefix = self._extract_subscription_url(http_request.url) - register_rp_status = await self._async_register_rp( - request, url_prefix, rp_name - ) + register_rp_status = await self._async_register_rp(request, url_prefix, rp_name) if not register_rp_status: return response # Change the 'x-ms-client-request-id' otherwise the Azure endpoint @@ -66,27 +61,20 @@ async def _async_register_rp(self, initial_request, url_prefix, rp_name): Return False if we have a reason to believe this didn't work """ - post_url = "{}providers/{}/register?api-version=2016-02-01".format( - url_prefix, rp_name - ) + post_url = "{}providers/{}/register?api-version=2016-02-01".format(url_prefix, rp_name) get_url = "{}providers/{}?api-version=2016-02-01".format(url_prefix, rp_name) _LOGGER.warning( - "Resource provider '%s' used by this operation is not " - "registered. We are registering for you.", + "Resource provider '%s' used by this operation is not " "registered. We are registering for you.", rp_name, ) - post_response = await self.next.send( - self._build_next_request(initial_request, "POST", post_url) - ) + post_response = await self.next.send(self._build_next_request(initial_request, "POST", post_url)) if post_response.http_response.status_code != 200: _LOGGER.warning("Registration failed. Please register manually.") return False while True: await asyncio.sleep(10) - get_response = await self.next.send( - self._build_next_request(initial_request, "GET", get_url) - ) + get_response = await self.next.send(self._build_next_request(initial_request, "GET", get_url)) rp_info = json.loads(get_response.http_response.text()) if rp_info["registrationState"] == "Registered": _LOGGER.warning("Registration succeeded.") diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/arm_polling.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/arm_polling.py index f450403e294e..0353a7d3d85c 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/arm_polling.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/arm_polling.py @@ -65,13 +65,10 @@ class _FinalStateViaOption(str, Enum, metaclass=CaseInsensitiveEnumMeta): class AzureAsyncOperationPolling(OperationResourcePolling): - """Implements a operation resource polling, typically from Azure-AsyncOperation. - """ + """Implements a operation resource polling, typically from Azure-AsyncOperation.""" def __init__(self, lro_options=None): - super(AzureAsyncOperationPolling, self).__init__( - operation_location_header="azure-asyncoperation" - ) + super(AzureAsyncOperationPolling, self).__init__(operation_location_header="azure-asyncoperation") self._lro_options = lro_options or {} @@ -82,14 +79,11 @@ def get_final_get_url(self, pipeline_response): :rtype: str """ if ( - self._lro_options.get(_LroOption.FINAL_STATE_VIA) - == _FinalStateViaOption.AZURE_ASYNC_OPERATION_FINAL_STATE + self._lro_options.get(_LroOption.FINAL_STATE_VIA) == _FinalStateViaOption.AZURE_ASYNC_OPERATION_FINAL_STATE and self._request.method == "POST" ): return None - return super(AzureAsyncOperationPolling, self).get_final_get_url( - pipeline_response - ) + return super(AzureAsyncOperationPolling, self).get_final_get_url(pipeline_response) class BodyContentPolling(LongRunningOperation): @@ -103,15 +97,13 @@ def __init__(self): def can_poll(self, pipeline_response): # type: (PipelineResponseType) -> bool - """Answer if this polling method could be used. - """ + """Answer if this polling method could be used.""" response = pipeline_response.http_response return response.request.method in ["PUT", "PATCH"] def get_polling_url(self): # type: () -> str - """Return the polling URL. - """ + """Return the polling URL.""" return self._initial_response.http_response.request.url def get_final_get_url(self, pipeline_response): @@ -167,9 +159,7 @@ def get_status(self, pipeline_response): """ response = pipeline_response.http_response if _is_empty(response): - raise BadResponse( - "The response from long running operation does not contain a body." - ) + raise BadResponse("The response from long running operation does not contain a body.") status = self._get_provisioning_state(response) return status or "Succeeded" @@ -177,12 +167,7 @@ def get_status(self, pipeline_response): class ARMPolling(LROBasePolling): def __init__( - self, - timeout=30, - lro_algorithms=None, - lro_options=None, - path_format_arguments=None, - **operation_config + self, timeout=30, lro_algorithms=None, lro_options=None, path_format_arguments=None, **operation_config ): lro_algorithms = lro_algorithms or [ AzureAsyncOperationPolling(lro_options=lro_options), @@ -198,8 +183,9 @@ def __init__( **operation_config ) + __all__ = [ - 'AzureAsyncOperationPolling', - 'BodyContentPolling', - 'ARMPolling', + "AzureAsyncOperationPolling", + "BodyContentPolling", + "ARMPolling", ] diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/async_arm_polling.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/async_arm_polling.py index 211322450df1..c79ea88df49d 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/async_arm_polling.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/polling/async_arm_polling.py @@ -32,4 +32,5 @@ class AsyncARMPolling(ARMPolling, AsyncLROBasePolling): pass + __all__ = ["AsyncARMPolling"] diff --git a/sdk/core/azure-mgmt-core/azure/mgmt/core/tools.py b/sdk/core/azure-mgmt-core/azure/mgmt/core/tools.py index c5d8de343e22..7cbcea51436a 100644 --- a/sdk/core/azure-mgmt-core/azure/mgmt/core/tools.py +++ b/sdk/core/azure-mgmt-core/azure/mgmt/core/tools.py @@ -34,10 +34,7 @@ "(/providers/(?P[^/]+)/(?P[^/]*)/(?P[^/]+)(?P.*))?" ) -_CHILDREN_RE = re.compile( - "(?i)(/providers/(?P[^/]+))?/" - "(?P[^/]*)/(?P[^/]+)" -) +_CHILDREN_RE = re.compile("(?i)(/providers/(?P[^/]+))?/" "(?P[^/]*)/(?P[^/]+)") _ARMNAME_RE = re.compile("^[^<>%&:\\?/]{1,260}$") @@ -85,12 +82,7 @@ def parse_resource_id(rid): children = _CHILDREN_RE.finditer(result["children"] or "") count = None for count, child in enumerate(children): - result.update( - { - key + "_%d" % (count + 1): group - for key, group in child.groupdict().items() - } - ) + result.update({key + "_%d" % (count + 1): group for key, group in child.groupdict().items()}) result["last_child_num"] = count + 1 if isinstance(count, int) else None result = _populate_alternate_kwargs(result) else: @@ -99,17 +91,13 @@ def parse_resource_id(rid): def _populate_alternate_kwargs(kwargs): - """ Translates the parsed arguments into a format used by generic ARM commands + """Translates the parsed arguments into a format used by generic ARM commands such as the resource and lock commands. """ resource_namespace = kwargs["namespace"] - resource_type = ( - kwargs.get("child_type_{}".format(kwargs["last_child_num"])) or kwargs["type"] - ) - resource_name = ( - kwargs.get("child_name_{}".format(kwargs["last_child_num"])) or kwargs["name"] - ) + resource_type = kwargs.get("child_type_{}".format(kwargs["last_child_num"])) or kwargs["type"] + resource_name = kwargs.get("child_name_{}".format(kwargs["last_child_num"])) or kwargs["name"] _get_parents_from_parts(kwargs) kwargs["resource_namespace"] = resource_namespace @@ -119,8 +107,7 @@ def _populate_alternate_kwargs(kwargs): def _get_parents_from_parts(kwargs): - """ Get the parents given all the children parameters. - """ + """Get the parents given all the children parameters.""" parent_builder = [] if kwargs["last_child_num"] is not None: parent_builder.append("{type}/{name}/".format(**kwargs)) @@ -129,17 +116,11 @@ def _get_parents_from_parts(kwargs): if child_namespace is not None: parent_builder.append("providers/{}/".format(child_namespace)) kwargs["child_parent_{}".format(index)] = "".join(parent_builder) - parent_builder.append( - "{{child_type_{0}}}/{{child_name_{0}}}/".format(index).format(**kwargs) - ) - child_namespace = kwargs.get( - "child_namespace_{}".format(kwargs["last_child_num"]) - ) + parent_builder.append("{{child_type_{0}}}/{{child_name_{0}}}/".format(index).format(**kwargs)) + child_namespace = kwargs.get("child_namespace_{}".format(kwargs["last_child_num"])) if child_namespace is not None: parent_builder.append("providers/{}/".format(child_namespace)) - kwargs["child_parent_{}".format(kwargs["last_child_num"])] = "".join( - parent_builder - ) + kwargs["child_parent_{}".format(kwargs["last_child_num"])] = "".join(parent_builder) kwargs["resource_parent"] = "".join(parent_builder) if kwargs["name"] else None return kwargs @@ -178,14 +159,10 @@ def resource_id(**kwargs): count = 1 while True: try: - rid_builder.append( - "providers/{{child_namespace_{}}}".format(count).format(**kwargs) - ) + rid_builder.append("providers/{{child_namespace_{}}}".format(count).format(**kwargs)) except KeyError: pass - rid_builder.append( - "{{child_type_{0}}}/{{child_name_{0}}}".format(count).format(**kwargs) - ) + rid_builder.append("{{child_type_{0}}}/{{child_name_{0}}}".format(count).format(**kwargs)) count += 1 except KeyError: pass diff --git a/sdk/core/azure-mgmt-core/setup.py b/sdk/core/azure-mgmt-core/setup.py index 9cad2b62ccab..40b05bc62615 100644 --- a/sdk/core/azure-mgmt-core/setup.py +++ b/sdk/core/azure-mgmt-core/setup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import re import os.path @@ -16,55 +16,56 @@ PACKAGE_PPRINT_NAME = "Management Core" # a-b-c => a/b/c -package_folder_path = PACKAGE_NAME.replace('-', '/') +package_folder_path = PACKAGE_NAME.replace("-", "/") # a-b-c => a.b.c -namespace_name = PACKAGE_NAME.replace('-', '.') +namespace_name = PACKAGE_NAME.replace("-", ".") # Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, '_version.py'), 'r') as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', # type: ignore - fd.read(), re.MULTILINE).group(1) +with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) # type: ignore if not version: - raise RuntimeError('Cannot find version information') + raise RuntimeError("Cannot find version information") -with open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() -with open('CHANGELOG.md', encoding='utf-8') as f: +with open("CHANGELOG.md", encoding="utf-8") as f: changelog = f.read() setup( name=PACKAGE_NAME, version=version, include_package_data=True, - description='Microsoft Azure {} Library for Python'.format(PACKAGE_PPRINT_NAME), - long_description=readme + '\n\n' + changelog, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='azpysdkhelp@microsoft.com', - url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-mgmt-core', + description="Microsoft Azure {} Library for Python".format(PACKAGE_PPRINT_NAME), + long_description=readme + "\n\n" + changelog, + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="azpysdkhelp@microsoft.com", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/core/azure-mgmt-core", classifiers=[ "Development Status :: 5 - Production/Stable", - 'Programming Language :: Python', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: MIT License', + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: MIT License", ], zip_safe=False, - packages=find_packages(exclude=[ - 'tests', - # Exclude packages that will be covered by PEP420 or nspkg - 'azure', - 'azure.mgmt', - ]), + packages=find_packages( + exclude=[ + "tests", + # Exclude packages that will be covered by PEP420 or nspkg + "azure", + "azure.mgmt", + ] + ), package_data={ - 'pytyped': ['py.typed'], + "pytyped": ["py.typed"], }, install_requires=[ "azure-core<2.0.0,>=1.26.2", diff --git a/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py b/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py index 74150fd7009f..4bf91ab6a78b 100644 --- a/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py +++ b/sdk/core/azure-mgmt-core/tests/asynctests/test_async_arm_polling.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import base64 import json import pickle @@ -39,11 +39,7 @@ from azure.core.pipeline import PipelineResponse, AsyncPipeline from azure.core.pipeline.transport import AsyncioRequestsTransportResponse, AsyncHttpTransport -from azure.core.polling.base_polling import ( - LongRunningOperation, - BadStatus, - LocationPolling -) +from azure.core.polling.base_polling import LongRunningOperation, BadStatus, LocationPolling from azure.mgmt.core.polling.async_arm_polling import ( AsyncARMPolling, ) @@ -66,23 +62,29 @@ def __repr__(self): def __eq__(self, other): return self.__dict__ == other.__dict__ + class BadEndpointError(Exception): pass -TEST_NAME = 'foo' -RESPONSE_BODY = {'properties':{'provisioningState': 'InProgress'}} -ASYNC_BODY = json.dumps({ 'status': 'Succeeded' }) -ASYNC_URL = 'http://dummyurlFromAzureAsyncOPHeader_Return200' -LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) -LOCATION_URL = 'http://dummyurlurlFromLocationHeader_Return200' -RESOURCE_BODY = json.dumps({ 'name': TEST_NAME }) -RESOURCE_URL = 'http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1' -ERROR = 'http://dummyurl_ReturnError' + +TEST_NAME = "foo" +RESPONSE_BODY = {"properties": {"provisioningState": "InProgress"}} +ASYNC_BODY = json.dumps({"status": "Succeeded"}) +ASYNC_URL = "http://dummyurlFromAzureAsyncOPHeader_Return200" +LOCATION_BODY = json.dumps({"name": TEST_NAME}) +LOCATION_URL = "http://dummyurlurlFromLocationHeader_Return200" +RESOURCE_BODY = json.dumps({"name": TEST_NAME}) +RESOURCE_URL = "http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1" +ERROR = "http://dummyurl_ReturnError" POLLING_STATUS = 200 CLIENT = AsyncPipelineClient("http://example.org") + + async def mock_run(client_self, request, **kwargs): return TestArmPolling.mock_update(request.url) + + CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -92,21 +94,23 @@ def async_pipeline_client_builder(): send will receive "request" and kwargs as any transport layer """ + def create_client(send_cb): class TestHttpTransport(AsyncHttpTransport): - async def open(self): pass - async def close(self): pass - async def __aexit__(self, *args, **kwargs): pass + async def open(self): + pass + + async def close(self): + pass + + async def __aexit__(self, *args, **kwargs): + pass async def send(self, request, **kwargs): return await send_cb(request, **kwargs) - return AsyncPipelineClient( - 'http://example.org/', - pipeline=AsyncPipeline( - transport=TestHttpTransport() - ) - ) + return AsyncPipelineClient("http://example.org/", pipeline=AsyncPipeline(transport=TestHttpTransport())) + return create_client @@ -114,106 +118,84 @@ async def send(self, request, **kwargs): def deserialization_cb(): def cb(pipeline_response): return json.loads(pipeline_response.http_response.text()) + return cb @pytest.mark.asyncio async def test_post(async_pipeline_client_builder, deserialization_cb): - # Test POST LRO with both Location and Azure-AsyncOperation - - # The initial response contains both Location and Azure-AsyncOperation, a 202 and no Body - initial_response = TestArmPolling.mock_send( - 'POST', - 202, - { - 'location': 'http://example.org/location', - 'azure-asyncoperation': 'http://example.org/async_monitor', - }, - '' - ) + # Test POST LRO with both Location and Azure-AsyncOperation + + # The initial response contains both Location and Azure-AsyncOperation, a 202 and no Body + initial_response = TestArmPolling.mock_send( + "POST", + 202, + { + "location": "http://example.org/location", + "azure-asyncoperation": "http://example.org/async_monitor", + }, + "", + ) + + async def send(request, **kwargs): + assert request.method == "GET" - async def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - return TestArmPolling.mock_send( - 'GET', - 200, - body={'location_result': True} - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestArmPolling.mock_send( - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = async_pipeline_client_builder(send) - - # Test 1, LRO options with Location final state - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncARMPolling(0, lro_options={"final-state-via": "location"})) - result = await poll - assert result['location_result'] == True - - # Test 2, LRO options with Azure-AsyncOperation final state - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncARMPolling(0, lro_options={"final-state-via": "azure-async-operation"})) - result = await poll - assert result['status'] == 'Succeeded' - - # Test 3, "do the right thing" and use Location by default - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncARMPolling(0)) - result = await poll - assert result['location_result'] == True - - # Test 4, location has no body - - async def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - return TestArmPolling.mock_send( - 'GET', - 200, - body=None - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestArmPolling.mock_send( - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = async_pipeline_client_builder(send) - - poll = async_poller( - client, - initial_response, - deserialization_cb, - AsyncARMPolling(0, lro_options={"final-state-via": "location"})) - result = await poll - assert result is None + if request.url == "http://example.org/location": + return TestArmPolling.mock_send("GET", 200, body={"location_result": True}).http_response + elif request.url == "http://example.org/async_monitor": + return TestArmPolling.mock_send("GET", 200, body={"status": "Succeeded"}).http_response + else: + pytest.fail("No other query allowed") + + client = async_pipeline_client_builder(send) + + # Test 1, LRO options with Location final state + poll = async_poller( + client, initial_response, deserialization_cb, AsyncARMPolling(0, lro_options={"final-state-via": "location"}) + ) + result = await poll + assert result["location_result"] == True + + # Test 2, LRO options with Azure-AsyncOperation final state + poll = async_poller( + client, + initial_response, + deserialization_cb, + AsyncARMPolling(0, lro_options={"final-state-via": "azure-async-operation"}), + ) + result = await poll + assert result["status"] == "Succeeded" + + # Test 3, "do the right thing" and use Location by default + poll = async_poller(client, initial_response, deserialization_cb, AsyncARMPolling(0)) + result = await poll + assert result["location_result"] == True + + # Test 4, location has no body + + async def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/location": + return TestArmPolling.mock_send("GET", 200, body=None).http_response + elif request.url == "http://example.org/async_monitor": + return TestArmPolling.mock_send("GET", 200, body={"status": "Succeeded"}).http_response + else: + pytest.fail("No other query allowed") + + client = async_pipeline_client_builder(send) + + poll = async_poller( + client, initial_response, deserialization_cb, AsyncARMPolling(0, lro_options={"final-state-via": "location"}) + ) + result = await poll + assert result is None class TestArmPolling(object): - convert = re.compile('([a-z0-9])([A-Z])') + convert = re.compile("([a-z0-9])([A-Z])") @staticmethod def mock_send(method, status, headers=None, body=RESPONSE_BODY): @@ -221,13 +203,11 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): headers = {} response = Response() response._content_consumed = True - response._content = json.dumps(body).encode('ascii') if body is not None else None + response._content = json.dumps(body).encode("ascii") if body is not None else None response.request = Request() response.request.method = method response.request.url = RESOURCE_URL - response.request.headers = { - 'x-ms-client-request-id': '67f4dd4e-6262-45e1-8bed-5c45cf23b6d9' - } + response.request.headers = {"x-ms-client-request-id": "67f4dd4e-6262-45e1-8bed-5c45cf23b6d9"} response.status_code = status response.headers = headers response.headers.update({"content-type": "application/json; charset=utf8"}) @@ -240,7 +220,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): response.request.headers, body, None, # form_content - None # stream_content + None, # stream_content ) return PipelineResponse( @@ -249,7 +229,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): request, response, ), - None # context + None, # context ) @staticmethod @@ -257,7 +237,7 @@ def mock_update(url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) - response.request.method = 'GET' + response.request.method = "GET" response.headers = headers or {} response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" @@ -265,13 +245,13 @@ def mock_update(url, headers=None): if url == ASYNC_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = ASYNC_BODY.encode('ascii') + response._content = ASYNC_BODY.encode("ascii") response.randomFieldFromPollAsyncOpHeader = None elif url == LOCATION_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = LOCATION_BODY.encode('ascii') + response._content = LOCATION_BODY.encode("ascii") response.randomFieldFromPollLocationHeader = None elif url == ERROR: @@ -280,19 +260,19 @@ def mock_update(url, headers=None): elif url == RESOURCE_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = RESOURCE_BODY.encode('ascii') + response._content = RESOURCE_BODY.encode("ascii") else: - raise Exception('URL does not match') + raise Exception("URL does not match") request = CLIENT._request( response.request.method, response.request.url, None, # params - {}, # request has no headers - None, # Request has no body + {}, # request has no headers + None, # Request has no body None, # form_content - None # stream_content + None, # stream_content ) return PipelineResponse( @@ -301,7 +281,7 @@ def mock_update(url, headers=None): request, response, ), - None # context + None, # context ) @staticmethod @@ -312,15 +292,13 @@ def mock_outputs(pipeline_response): except ValueError: raise DecodeError("Impossible to deserialize") - body = {TestArmPolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in body.items()} - properties = body.setdefault('properties', {}) - if 'name' in body: - properties['name'] = body['name'] + body = {TestArmPolling.convert.sub(r"\1_\2", k).lower(): v for k, v in body.items()} + properties = body.setdefault("properties", {}) + if "name" in body: + properties["name"] = body["name"] if properties: - properties = {TestArmPolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in properties.items()} - del body['properties'] + properties = {TestArmPolling.convert.sub(r"\1_\2", k).lower(): v for k, v in properties.items()} + del body["properties"] body.update(properties) resource = SimpleResource(**body) else: @@ -330,230 +308,162 @@ def mock_outputs(pipeline_response): @staticmethod def mock_deserialization_no_body(pipeline_response): - """Use this mock when you don't expect a return (last body irrelevant) - """ + """Use this mock when you don't expect a return (last body irrelevant)""" return None + @pytest.mark.asyncio async def test_long_running_put(): - #TODO: Test custom header field + # TODO: Test custom header field # Test throw on non LRO related status code - response = TestArmPolling.mock_send('PUT', 1000, {}) + response = TestArmPolling.mock_send("PUT", 1000, {}) with pytest.raises(HttpResponseError): - await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) # Test with no polling necessary - response_body = { - 'properties':{'provisioningState': 'Succeeded'}, - 'name': TEST_NAME - } - response = TestArmPolling.mock_send( - 'PUT', 201, - {}, response_body - ) + response_body = {"properties": {"provisioningState": "Succeeded"}, "name": TEST_NAME} + response = TestArmPolling.mock_send("PUT", 201, {}, response_body) + def no_update_allowed(url, headers=None): raise ValueError("Should not try to update") + polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method - ) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'azure-asyncoperation': ASYNC_URL}) + response = TestArmPolling.mock_send("PUT", 201, {"azure-asyncoperation": ASYNC_URL}) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling location header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'location': LOCATION_URL}) + response = TestArmPolling.mock_send("PUT", 201, {"location": LOCATION_URL}) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise - response = TestArmPolling.mock_send( - 'PUT', 201, - {'location': LOCATION_URL}, response_body) + response = TestArmPolling.mock_send("PUT", 201, {"location": LOCATION_URL}, response_body) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test fail to poll from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'azure-asyncoperation': ERROR}) + response = TestArmPolling.mock_send("PUT", 201, {"azure-asyncoperation": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) # Test fail to poll from location header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'location': ERROR}) + response = TestArmPolling.mock_send("PUT", 201, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) + @pytest.mark.asyncio async def test_long_running_patch(): # Test polling from location header response = TestArmPolling.mock_send( - 'PATCH', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", 202, {"location": LOCATION_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'PATCH', 202, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", 202, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from location header response = TestArmPolling.mock_send( - 'PATCH', 200, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", 200, {"location": LOCATION_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'PATCH', 200, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "PATCH", 200, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME - assert not hasattr(polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test fail to poll from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'PATCH', 202, - {'azure-asyncoperation': ERROR}) + response = TestArmPolling.mock_send("PATCH", 202, {"azure-asyncoperation": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) # Test fail to poll from location header - response = TestArmPolling.mock_send( - 'PATCH', 202, - {'location': ERROR}) + response = TestArmPolling.mock_send("PATCH", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) + @pytest.mark.asyncio async def test_long_running_delete(): # Test polling from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'DELETE', 202, - {'azure-asyncoperation': ASYNC_URL}, - body="" - ) + response = TestArmPolling.mock_send("DELETE", 202, {"azure-asyncoperation": ASYNC_URL}, body="") polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_deserialization_no_body, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_deserialization_no_body, polling_method) assert poll is None assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + @pytest.mark.asyncio async def test_long_running_post(): # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'POST', 201, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", 201, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_deserialization_no_body, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_deserialization_no_body, polling_method) assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'POST', 202, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", 202, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_deserialization_no_body, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_deserialization_no_body, polling_method) assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None # Test polling from location header response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) + "POST", 202, {"location": LOCATION_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) polling_method = AsyncARMPolling(0) - poll = await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - polling_method) + poll = await async_poller(CLIENT, response, TestArmPolling.mock_outputs, polling_method) assert poll.name == TEST_NAME assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None # Test fail to poll from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'POST', 202, - {'azure-asyncoperation': ERROR}) + response = TestArmPolling.mock_send("POST", 202, {"azure-asyncoperation": ERROR}) with pytest.raises(BadEndpointError): - await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) # Test fail to poll from location header - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': ERROR}) + response = TestArmPolling.mock_send("POST", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - await async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + await async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) + @pytest.mark.asyncio async def test_long_running_negative(): @@ -561,52 +471,37 @@ async def test_long_running_negative(): global POLLING_STATUS # Test LRO PUT throws for invalid json - LOCATION_BODY = '{' - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}) - poll = async_poller( - CLIENT, - response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0) - ) + LOCATION_BODY = "{" + response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) + poll = async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) with pytest.raises(DecodeError): await poll - LOCATION_BODY = '{\'"}' - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}) - poll = async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) + LOCATION_BODY = "{'\"}" + response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) + poll = async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) with pytest.raises(DecodeError): await poll - LOCATION_BODY = '{' + LOCATION_BODY = "{" POLLING_STATUS = 203 - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}) - poll = async_poller(CLIENT, response, - TestArmPolling.mock_outputs, - AsyncARMPolling(0)) - with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization + response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) + poll = async_poller(CLIENT, response, TestArmPolling.mock_outputs, AsyncARMPolling(0)) + with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization await poll - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode('ascii') + assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") - LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) + LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 + def test_polling_with_path_format_arguments(): - method = AsyncARMPolling( - timeout=0, - path_format_arguments={"host": "host:3000", "accountName": "local"} - ) + method = AsyncARMPolling(timeout=0, path_format_arguments={"host": "host:3000", "accountName": "local"}) client = AsyncPipelineClient(base_url="http://{accountName}{host}") method._operation = LocationPolling() method._operation._location_url = "/results/1" method._client = client - assert "http://localhost:3000/results/1" == method._client.format_url(method._operation.get_polling_url(), **method._path_format_arguments) \ No newline at end of file + assert "http://localhost:3000/results/1" == method._client.format_url( + method._operation.get_polling_url(), **method._path_format_arguments + ) diff --git a/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py b/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py index d922594dbf1a..e8ca1123c788 100644 --- a/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py +++ b/sdk/core/azure-mgmt-core/tests/asynctests/test_policies_async.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # # The MIT License (MIT) @@ -21,26 +21,28 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from azure.mgmt.core import AsyncARMPipelineClient from azure.mgmt.core.policies import ARMHttpLoggingPolicy from azure.core.configuration import Configuration + def test_default_http_logging_policy(): config = Configuration() pipeline_client = AsyncARMPipelineClient(base_url="test", config=config) http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST + def test_pass_in_http_logging_policy(): config = Configuration() http_logging_policy = ARMHttpLoggingPolicy() - http_logging_policy.allowed_header_names.update( - {"x-ms-added-header"} - ) + http_logging_policy.allowed_header_names.update({"x-ms-added-header"}) config.http_logging_policy = http_logging_policy pipeline_client = AsyncARMPipelineClient(base_url="test", config=config) http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy - assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) \ No newline at end of file + assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union( + {"x-ms-added-header"} + ) diff --git a/sdk/core/azure-mgmt-core/tests/conftest.py b/sdk/core/azure-mgmt-core/tests/conftest.py index 0dc30804dae2..72916d0e7601 100644 --- a/sdk/core/azure-mgmt-core/tests/conftest.py +++ b/sdk/core/azure-mgmt-core/tests/conftest.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import json import os.path import sys @@ -31,9 +31,10 @@ CWD = os.path.dirname(__file__) + def pytest_addoption(parser): - parser.addoption("--runslow", action="store_true", - default=False, help="run slow tests") + parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") + def pytest_collection_modifyitems(config, items): if config.getoption("--runslow"): @@ -52,6 +53,4 @@ def user_password(): with open(filepath, "r") as fd: userpass = json.load(fd)["userpass"] return userpass["user"], userpass["password"] - raise ValueError("Create a {} file with a 'userpass' key and two keys 'user' and 'password'".format( - filepath - )) + raise ValueError("Create a {} file with a 'userpass' key and two keys 'user' and 'password'".format(filepath)) diff --git a/sdk/core/azure-mgmt-core/tests/test_arm_polling.py b/sdk/core/azure-mgmt-core/tests/test_arm_polling.py index 6e12379c5b2a..d1c6b3cda64c 100644 --- a/sdk/core/azure-mgmt-core/tests/test_arm_polling.py +++ b/sdk/core/azure-mgmt-core/tests/test_arm_polling.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import base64 import json import pickle @@ -39,15 +39,12 @@ from azure.core.pipeline import PipelineResponse, Pipeline from azure.core.pipeline.transport import RequestsTransportResponse, HttpTransport -from azure.core.polling.base_polling import ( - LongRunningOperation, - BadStatus, - LocationPolling -) +from azure.core.polling.base_polling import LongRunningOperation, BadStatus, LocationPolling from azure.mgmt.core.polling.arm_polling import ( ARMPolling, ) + class SimpleResource: """An implementation of Python 3 SimpleNamespace. Used to deserialize resource objects from response bodies where @@ -65,23 +62,29 @@ def __repr__(self): def __eq__(self, other): return self.__dict__ == other.__dict__ + class BadEndpointError(Exception): pass -TEST_NAME = 'foo' -RESPONSE_BODY = {'properties':{'provisioningState': 'InProgress'}} -ASYNC_BODY = json.dumps({ 'status': 'Succeeded' }) -ASYNC_URL = 'http://dummyurlFromAzureAsyncOPHeader_Return200' -LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) -LOCATION_URL = 'http://dummyurlurlFromLocationHeader_Return200' -RESOURCE_BODY = json.dumps({ 'name': TEST_NAME }) -RESOURCE_URL = 'http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1' -ERROR = 'http://dummyurl_ReturnError' + +TEST_NAME = "foo" +RESPONSE_BODY = {"properties": {"provisioningState": "InProgress"}} +ASYNC_BODY = json.dumps({"status": "Succeeded"}) +ASYNC_URL = "http://dummyurlFromAzureAsyncOPHeader_Return200" +LOCATION_BODY = json.dumps({"name": TEST_NAME}) +LOCATION_URL = "http://dummyurlurlFromLocationHeader_Return200" +RESOURCE_BODY = json.dumps({"name": TEST_NAME}) +RESOURCE_URL = "http://subscriptions/sub1/resourcegroups/g1/resourcetype1/resource1" +ERROR = "http://dummyurl_ReturnError" POLLING_STATUS = 200 CLIENT = PipelineClient("http://example.org") + + def mock_run(client_self, request, **kwargs): return TestArmPolling.mock_update(request.url, request.headers) + + CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -91,21 +94,23 @@ def pipeline_client_builder(): send will receive "request" and kwargs as any transport layer """ + def create_client(send_cb): class TestHttpTransport(HttpTransport): - def open(self): pass - def close(self): pass - def __exit__(self, *args, **kwargs): pass + def open(self): + pass + + def close(self): + pass + + def __exit__(self, *args, **kwargs): + pass def send(self, request, **kwargs): return send_cb(request, **kwargs) - return PipelineClient( - 'http://example.org/', - pipeline=Pipeline( - transport=TestHttpTransport() - ) - ) + return PipelineClient("http://example.org/", pipeline=Pipeline(transport=TestHttpTransport())) + return create_client @@ -113,105 +118,83 @@ def send(self, request, **kwargs): def deserialization_cb(): def cb(pipeline_response): return json.loads(pipeline_response.http_response.text()) + return cb def test_post(pipeline_client_builder, deserialization_cb): - # Test POST LRO with both Location and Azure-AsyncOperation - - # The initial response contains both Location and Azure-AsyncOperation, a 202 and no Body - initial_response = TestArmPolling.mock_send( - 'POST', - 202, - { - 'location': 'http://example.org/location', - 'azure-asyncoperation': 'http://example.org/async_monitor', - }, - '' - ) + # Test POST LRO with both Location and Azure-AsyncOperation + + # The initial response contains both Location and Azure-AsyncOperation, a 202 and no Body + initial_response = TestArmPolling.mock_send( + "POST", + 202, + { + "location": "http://example.org/location", + "azure-asyncoperation": "http://example.org/async_monitor", + }, + "", + ) + + def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/location": + return TestArmPolling.mock_send("GET", 200, body={"location_result": True}).http_response + elif request.url == "http://example.org/async_monitor": + return TestArmPolling.mock_send("GET", 200, body={"status": "Succeeded"}).http_response + else: + pytest.fail("No other query allowed") + + client = pipeline_client_builder(send) + + # Test 1, LRO options with Location final state + poll = LROPoller( + client, initial_response, deserialization_cb, ARMPolling(0, lro_options={"final-state-via": "location"}) + ) + result = poll.result() + assert result["location_result"] == True + + # Test 2, LRO options with Azure-AsyncOperation final state + poll = LROPoller( + client, + initial_response, + deserialization_cb, + ARMPolling(0, lro_options={"final-state-via": "azure-async-operation"}), + ) + result = poll.result() + assert result["status"] == "Succeeded" + + # Test 3, "do the right thing" and use Location by default + poll = LROPoller(client, initial_response, deserialization_cb, ARMPolling(0)) + result = poll.result() + assert result["location_result"] == True + + # Test 4, location has no body + + def send(request, **kwargs): + assert request.method == "GET" + + if request.url == "http://example.org/location": + return TestArmPolling.mock_send("GET", 200, body=None).http_response + elif request.url == "http://example.org/async_monitor": + return TestArmPolling.mock_send("GET", 200, body={"status": "Succeeded"}).http_response + else: + pytest.fail("No other query allowed") + + client = pipeline_client_builder(send) - def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - return TestArmPolling.mock_send( - 'GET', - 200, - body={'location_result': True} - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestArmPolling.mock_send( - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = pipeline_client_builder(send) - - # Test 1, LRO options with Location final state - poll = LROPoller( - client, - initial_response, - deserialization_cb, - ARMPolling(0, lro_options={"final-state-via": "location"})) - result = poll.result() - assert result['location_result'] == True - - # Test 2, LRO options with Azure-AsyncOperation final state - poll = LROPoller( - client, - initial_response, - deserialization_cb, - ARMPolling(0, lro_options={"final-state-via": "azure-async-operation"})) - result = poll.result() - assert result['status'] == 'Succeeded' - - # Test 3, "do the right thing" and use Location by default - poll = LROPoller( - client, - initial_response, - deserialization_cb, - ARMPolling(0)) - result = poll.result() - assert result['location_result'] == True - - # Test 4, location has no body - - def send(request, **kwargs): - assert request.method == 'GET' - - if request.url == 'http://example.org/location': - return TestArmPolling.mock_send( - 'GET', - 200, - body=None - ).http_response - elif request.url == 'http://example.org/async_monitor': - return TestArmPolling.mock_send( - 'GET', - 200, - body={'status': 'Succeeded'} - ).http_response - else: - pytest.fail("No other query allowed") - - client = pipeline_client_builder(send) - - poll = LROPoller( - client, - initial_response, - deserialization_cb, - ARMPolling(0, lro_options={"final-state-via": "location"})) - result = poll.result() - assert result is None + poll = LROPoller( + client, initial_response, deserialization_cb, ARMPolling(0, lro_options={"final-state-via": "location"}) + ) + result = poll.result() + assert result is None class TestArmPolling(object): - convert = re.compile('([a-z0-9])([A-Z])') + convert = re.compile("([a-z0-9])([A-Z])") @staticmethod def mock_send(method, status, headers=None, body=RESPONSE_BODY): @@ -219,13 +202,11 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): headers = {} response = Response() response._content_consumed = True - response._content = json.dumps(body).encode('ascii') if body is not None else None + response._content = json.dumps(body).encode("ascii") if body is not None else None response.request = Request() response.request.method = method response.request.url = RESOURCE_URL - response.request.headers = { - 'x-ms-client-request-id': '67f4dd4e-6262-45e1-8bed-5c45cf23b6d9' - } + response.request.headers = {"x-ms-client-request-id": "67f4dd4e-6262-45e1-8bed-5c45cf23b6d9"} response.status_code = status response.headers = headers response.headers.update({"content-type": "application/json; charset=utf8"}) @@ -238,7 +219,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): response.request.headers, body, None, # form_content - None # stream_content + None, # stream_content ) return PipelineResponse( @@ -247,7 +228,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): request, response, ), - None # context + None, # context ) @staticmethod @@ -255,7 +236,7 @@ def mock_update(url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) - response.request.method = 'GET' + response.request.method = "GET" response.headers = headers or {} response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" @@ -263,13 +244,13 @@ def mock_update(url, headers=None): if url == ASYNC_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = ASYNC_BODY.encode('ascii') + response._content = ASYNC_BODY.encode("ascii") response.randomFieldFromPollAsyncOpHeader = None elif url == LOCATION_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = LOCATION_BODY.encode('ascii') + response._content = LOCATION_BODY.encode("ascii") response.randomFieldFromPollLocationHeader = None elif url == ERROR: @@ -278,19 +259,19 @@ def mock_update(url, headers=None): elif url == RESOURCE_URL: response.request.url = url response.status_code = POLLING_STATUS - response._content = RESOURCE_BODY.encode('ascii') + response._content = RESOURCE_BODY.encode("ascii") else: - raise Exception('URL does not match') + raise Exception("URL does not match") request = CLIENT._request( response.request.method, response.request.url, None, # params - {}, # request has no headers - None, # Request has no body + {}, # request has no headers + None, # Request has no body None, # form_content - None # stream_content + None, # stream_content ) return PipelineResponse( @@ -299,7 +280,7 @@ def mock_update(url, headers=None): request, response, ), - None # context + None, # context ) @staticmethod @@ -311,15 +292,13 @@ def mock_outputs(pipeline_response): raise DecodeError("Impossible to deserialize") body = json.loads(response.text()) - body = {TestArmPolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in body.items()} - properties = body.setdefault('properties', {}) - if 'name' in body: - properties['name'] = body['name'] + body = {TestArmPolling.convert.sub(r"\1_\2", k).lower(): v for k, v in body.items()} + properties = body.setdefault("properties", {}) + if "name" in body: + properties["name"] = body["name"] if properties: - properties = {TestArmPolling.convert.sub(r'\1_\2', k).lower(): v - for k, v in properties.items()} - del body['properties'] + properties = {TestArmPolling.convert.sub(r"\1_\2", k).lower(): v for k, v in properties.items()} + del body["properties"] body.update(properties) resource = SimpleResource(**body) else: @@ -329,270 +308,204 @@ def mock_outputs(pipeline_response): @staticmethod def mock_deserialization_no_body(pipeline_response): - """Use this mock when you don't expect a return (last body irrelevant) - """ + """Use this mock when you don't expect a return (last body irrelevant)""" return None def test_long_running_put(self): - #TODO: Test custom header field + # TODO: Test custom header field # Test throw on non LRO related status code - response = TestArmPolling.mock_send('PUT', 1000, {}) + response = TestArmPolling.mock_send("PUT", 1000, {}) with pytest.raises(HttpResponseError): - LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() # Test with no polling necessary - response_body = { - 'properties':{'provisioningState': 'Succeeded'}, - 'name': TEST_NAME - } - response = TestArmPolling.mock_send( - 'PUT', 201, - {}, response_body - ) + response_body = {"properties": {"provisioningState": "Succeeded"}, "name": TEST_NAME} + response = TestArmPolling.mock_send("PUT", 201, {}, response_body) + def no_update_allowed(url, headers=None): raise ValueError("Should not try to update") - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0) - ) + + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'azure-asyncoperation': ASYNC_URL}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + response = TestArmPolling.mock_send("PUT", 201, {"azure-asyncoperation": ASYNC_URL}) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling location header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'location': LOCATION_URL}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + response = TestArmPolling.mock_send("PUT", 201, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise - response = TestArmPolling.mock_send( - 'PUT', 201, - {'location': LOCATION_URL}, response_body) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + response = TestArmPolling.mock_send("PUT", 201, {"location": LOCATION_URL}, response_body) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test fail to poll from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'azure-asyncoperation': ERROR}) + response = TestArmPolling.mock_send("PUT", 201, {"azure-asyncoperation": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() # Test fail to poll from location header - response = TestArmPolling.mock_send( - 'PUT', 201, - {'location': ERROR}) + response = TestArmPolling.mock_send("PUT", 201, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() def test_long_running_patch(self): # Test polling from location header response = TestArmPolling.mock_send( - 'PATCH', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + "PATCH", 202, {"location": LOCATION_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'PATCH', 202, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + "PATCH", 202, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test polling from location header response = TestArmPolling.mock_send( - 'PATCH', 200, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + "PATCH", 200, {"location": LOCATION_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'PATCH', 200, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + "PATCH", 200, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert not hasattr(poll._polling_method._pipeline_response, 'randomFieldFromPollAsyncOpHeader') + assert not hasattr(poll._polling_method._pipeline_response, "randomFieldFromPollAsyncOpHeader") # Test fail to poll from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'PATCH', 202, - {'azure-asyncoperation': ERROR}) + response = TestArmPolling.mock_send("PATCH", 202, {"azure-asyncoperation": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() # Test fail to poll from location header - response = TestArmPolling.mock_send( - 'PATCH', 202, - {'location': ERROR}) + response = TestArmPolling.mock_send("PATCH", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() def test_long_running_delete(self): # Test polling from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'DELETE', 202, - {'azure-asyncoperation': ASYNC_URL}, - body="" - ) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + response = TestArmPolling.mock_send("DELETE", 202, {"azure-asyncoperation": ASYNC_URL}, body="") + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) poll.wait() - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader + is None + ) def test_long_running_post_legacy(self): # Former oooooold tests to refactor one day to something more readble # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'POST', 201, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_deserialization_no_body, - ARMPolling(0)) + "POST", 201, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_deserialization_no_body, ARMPolling(0)) poll.wait() - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader + is None + ) # Test polling from azure-asyncoperation header response = TestArmPolling.mock_send( - 'POST', 202, - {'azure-asyncoperation': ASYNC_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_deserialization_no_body, - ARMPolling(0)) + "POST", 202, {"azure-asyncoperation": ASYNC_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_deserialization_no_body, ARMPolling(0)) poll.wait() - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader + is None + ) # Test polling from location header response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}, - body={'properties':{'provisioningState': 'Succeeded'}}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + "POST", 202, {"location": LOCATION_URL}, body={"properties": {"provisioningState": "Succeeded"}} + ) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) assert poll.result().name == TEST_NAME - assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader is None + assert ( + poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollLocationHeader + is None + ) # Test fail to poll from azure-asyncoperation header - response = TestArmPolling.mock_send( - 'POST', 202, - {'azure-asyncoperation': ERROR}) + response = TestArmPolling.mock_send("POST", 202, {"azure-asyncoperation": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() # Test fail to poll from location header - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': ERROR}) + response = TestArmPolling.mock_send("POST", 202, {"location": ERROR}) with pytest.raises(BadEndpointError): - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)).result() + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)).result() def test_long_running_negative(self): global LOCATION_BODY global POLLING_STATUS # Test LRO PUT throws for invalid json - LOCATION_BODY = '{' - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}) - poll = LROPoller( - CLIENT, - response, - TestArmPolling.mock_outputs, - ARMPolling(0) - ) + LOCATION_BODY = "{" + response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) with pytest.raises(DecodeError): poll.result() - LOCATION_BODY = '{\'"}' - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) + LOCATION_BODY = "{'\"}" + response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) with pytest.raises(DecodeError): poll.result() - LOCATION_BODY = '{' + LOCATION_BODY = "{" POLLING_STATUS = 203 - response = TestArmPolling.mock_send( - 'POST', 202, - {'location': LOCATION_URL}) - poll = LROPoller(CLIENT, response, - TestArmPolling.mock_outputs, - ARMPolling(0)) - with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization + response = TestArmPolling.mock_send("POST", 202, {"location": LOCATION_URL}) + poll = LROPoller(CLIENT, response, TestArmPolling.mock_outputs, ARMPolling(0)) + with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization poll.result() - assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode('ascii') + assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii") - LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) + LOCATION_BODY = json.dumps({"name": TEST_NAME}) POLLING_STATUS = 200 def test_polling_with_path_format_arguments(self): - method = ARMPolling( - timeout=0, - path_format_arguments={"host": "host:3000", "accountName": "local"} - ) + method = ARMPolling(timeout=0, path_format_arguments={"host": "host:3000", "accountName": "local"}) client = PipelineClient(base_url="http://{accountName}{host}") method._operation = LocationPolling() method._operation._location_url = "/results/1" method._client = client - assert "http://localhost:3000/results/1" == method._client.format_url(method._operation.get_polling_url(), **method._path_format_arguments) - + assert "http://localhost:3000/results/1" == method._client.format_url( + method._operation.get_polling_url(), **method._path_format_arguments + ) diff --git a/sdk/core/azure-mgmt-core/tests/test_authentication.py b/sdk/core/azure-mgmt-core/tests/test_authentication.py index 7a4b5bfa1111..c1f3590e5050 100644 --- a/sdk/core/azure-mgmt-core/tests/test_authentication.py +++ b/sdk/core/azure-mgmt-core/tests/test_authentication.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import base64 import time diff --git a/sdk/core/azure-mgmt-core/tests/test_mgmt_exceptions.py b/sdk/core/azure-mgmt-core/tests/test_mgmt_exceptions.py index 8e1e40a1c3c8..457434917113 100644 --- a/sdk/core/azure-mgmt-core/tests/test_mgmt_exceptions.py +++ b/sdk/core/azure-mgmt-core/tests/test_mgmt_exceptions.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. # @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import functools import json @@ -34,12 +34,13 @@ ARMError = functools.partial(HttpResponseError, error_format=ARMErrorFormat) + def _build_response(json_body): class MockResponse(_HttpResponseBase): def __init__(self): super(MockResponse, self).__init__( request=None, - internal_response = None, + internal_response=None, ) self.status_code = 400 self.reason = "Bad Request" @@ -66,26 +67,17 @@ def test_arm_exception(): "message": "$search query option not supported", } ], - "innererror": { - "customKey": "customValue" - }, - "additionalInfo": [ - { - "type": "SomeErrorType", - "info": { - "customKey": "customValue" - } - } - ] + "innererror": {"customKey": "customValue"}, + "additionalInfo": [{"type": "SomeErrorType", "info": {"customKey": "customValue"}}], } } cloud_exp = ARMError(response=_build_response(json.dumps(message).encode("utf-8"))) - assert cloud_exp.error.target == 'query' - assert cloud_exp.error.details[0].target == '$search' - assert cloud_exp.error.innererror['customKey'] == 'customValue' - assert cloud_exp.error.additional_info[0].type == 'SomeErrorType' - assert cloud_exp.error.additional_info[0].info['customKey'] == 'customValue' - assert 'customValue' in str(cloud_exp) + assert cloud_exp.error.target == "query" + assert cloud_exp.error.details[0].target == "$search" + assert cloud_exp.error.innererror["customKey"] == "customValue" + assert cloud_exp.error.additional_info[0].type == "SomeErrorType" + assert cloud_exp.error.additional_info[0].info["customKey"] == "customValue" + assert "customValue" in str(cloud_exp) message = { "error": { @@ -102,60 +94,58 @@ def test_arm_exception(): "type": "PolicyViolation", "info": { "policyDefinitionDisplayName": "Allowed locations", - "policyAssignmentParameters": { - "listOfAllowedLocations": { - "value": [ - "westus" - ] - } - } - } + "policyAssignmentParameters": {"listOfAllowedLocations": {"value": ["westus"]}}, + }, } - ] + ], } ], - "additionalInfo": [ - { - "type": "SomeErrorType", - "info": { - "customKey": "customValue" - } - } - ] + "additionalInfo": [{"type": "SomeErrorType", "info": {"customKey": "customValue"}}], } } cloud_exp = ARMError(response=_build_response(json.dumps(message).encode("utf-8"))) - assert cloud_exp.error.target == 'query' - assert cloud_exp.error.details[0].target == '$search' - assert cloud_exp.error.additional_info[0].type == 'SomeErrorType' - assert cloud_exp.error.additional_info[0].info['customKey'] == 'customValue' - assert cloud_exp.error.details[0].additional_info[0].type == 'PolicyViolation' - assert cloud_exp.error.details[0].additional_info[0].info['policyDefinitionDisplayName'] == 'Allowed locations' - assert cloud_exp.error.details[0].additional_info[0].info['policyAssignmentParameters']['listOfAllowedLocations']['value'][0] == 'westus' - assert 'customValue' in str(cloud_exp) - + assert cloud_exp.error.target == "query" + assert cloud_exp.error.details[0].target == "$search" + assert cloud_exp.error.additional_info[0].type == "SomeErrorType" + assert cloud_exp.error.additional_info[0].info["customKey"] == "customValue" + assert cloud_exp.error.details[0].additional_info[0].type == "PolicyViolation" + assert cloud_exp.error.details[0].additional_info[0].info["policyDefinitionDisplayName"] == "Allowed locations" + assert ( + cloud_exp.error.details[0] + .additional_info[0] + .info["policyAssignmentParameters"]["listOfAllowedLocations"]["value"][0] + == "westus" + ) + assert "customValue" in str(cloud_exp) error = ARMError(response=_build_response(b"{{")) assert "Bad Request" in error.message - error = ARMError(response=_build_response(b'{"error":{"code":"Conflict","message":"The maximum number of Free ServerFarms allowed in a Subscription is 10.","target":null,"details":[{"message":"The maximum number of Free ServerFarms allowed in a Subscription is 10."},{"code":"Conflict"},{"errorentity":{"code":"Conflict","message":"The maximum number of Free ServerFarms allowed in a Subscription is 10.","extendedCode":"59301","messageTemplate":"The maximum number of {0} ServerFarms allowed in a Subscription is {1}.","parameters":["Free","10"],"innerErrors":null}}],"innererror":null}}')) - assert error.error.code == "Conflict" - - message = json.dumps({ - "error": { - "code": "BadArgument", - "message": "The provided database 'foo' has an invalid username.", - "target": "query", - "details": [ - { - "code": "301", - "target": "$search", - "message": "$search query option not supported", - } - ] - }}).encode('utf-8') + error = ARMError( + response=_build_response( + b'{"error":{"code":"Conflict","message":"The maximum number of Free ServerFarms allowed in a Subscription is 10.","target":null,"details":[{"message":"The maximum number of Free ServerFarms allowed in a Subscription is 10."},{"code":"Conflict"},{"errorentity":{"code":"Conflict","message":"The maximum number of Free ServerFarms allowed in a Subscription is 10.","extendedCode":"59301","messageTemplate":"The maximum number of {0} ServerFarms allowed in a Subscription is {1}.","parameters":["Free","10"],"innerErrors":null}}],"innererror":null}}' + ) + ) + assert error.error.code == "Conflict" + + message = json.dumps( + { + "error": { + "code": "BadArgument", + "message": "The provided database 'foo' has an invalid username.", + "target": "query", + "details": [ + { + "code": "301", + "target": "$search", + "message": "$search query option not supported", + } + ], + } + } + ).encode("utf-8") error = ARMError(response=_build_response(message)) - assert error.error.code == "BadArgument" + assert error.error.code == "BadArgument" # See https://github.com/Azure/msrestazure-for-python/issues/54 response_text = b'"{\\"error\\": {\\"code\\": \\"ResourceGroupNotFound\\", \\"message\\": \\"Resource group \'res_grp\' could not be found.\\"}}"' diff --git a/sdk/core/azure-mgmt-core/tests/test_policies.py b/sdk/core/azure-mgmt-core/tests/test_policies.py index 5a9b4b972448..b70897f0f403 100644 --- a/sdk/core/azure-mgmt-core/tests/test_policies.py +++ b/sdk/core/azure-mgmt-core/tests/test_policies.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # # The MIT License (MIT) @@ -21,7 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import json import time @@ -37,16 +37,15 @@ ) from azure.mgmt.core import ARMPipelineClient -from azure.mgmt.core.policies import ( - ARMAutoResourceProviderRegistrationPolicy, - ARMHttpLoggingPolicy -) +from azure.mgmt.core.policies import ARMAutoResourceProviderRegistrationPolicy, ARMHttpLoggingPolicy + @pytest.fixture def sleepless(monkeypatch): def sleep(_): pass - monkeypatch.setattr(time, 'sleep', sleep) + + monkeypatch.setattr(time, "sleep", sleep) @httpretty.activate @@ -59,51 +58,55 @@ def test_register_rp_policy(): - We call again the first endpoint and this time this succeed """ - provider_url = ("https://management.azure.com/" - "subscriptions/12345678-9abc-def0-0000-000000000000/" - "resourceGroups/clitest.rg000001/" - "providers/Microsoft.Sql/servers/ygserver123?api-version=2014-04-01") + provider_url = ( + "https://management.azure.com/" + "subscriptions/12345678-9abc-def0-0000-000000000000/" + "resourceGroups/clitest.rg000001/" + "providers/Microsoft.Sql/servers/ygserver123?api-version=2014-04-01" + ) - provider_error = ('{"error":{"code":"MissingSubscriptionRegistration", ' - '"message":"The subscription registration is in \'Unregistered\' state. ' - 'The subscription must be registered to use namespace \'Microsoft.Sql\'. ' - 'See https://aka.ms/rps-not-found for how to register subscriptions."}}') + provider_error = ( + '{"error":{"code":"MissingSubscriptionRegistration", ' + '"message":"The subscription registration is in \'Unregistered\' state. ' + "The subscription must be registered to use namespace 'Microsoft.Sql'. " + 'See https://aka.ms/rps-not-found for how to register subscriptions."}}' + ) provider_success = '{"success": true}' - httpretty.register_uri(httpretty.PUT, - provider_url, - responses=[ - httpretty.Response(body=provider_error, status=409), - httpretty.Response(body=provider_success), - ], - content_type="application/json") + httpretty.register_uri( + httpretty.PUT, + provider_url, + responses=[ + httpretty.Response(body=provider_error, status=409), + httpretty.Response(body=provider_success), + ], + content_type="application/json", + ) - register_post_url = ("https://management.azure.com/" - "subscriptions/12345678-9abc-def0-0000-000000000000/" - "providers/Microsoft.Sql/register?api-version=2016-02-01") + register_post_url = ( + "https://management.azure.com/" + "subscriptions/12345678-9abc-def0-0000-000000000000/" + "providers/Microsoft.Sql/register?api-version=2016-02-01" + ) - register_post_result = { - "registrationState":"Registering" - } + register_post_result = {"registrationState": "Registering"} - register_get_url = ("https://management.azure.com/" - "subscriptions/12345678-9abc-def0-0000-000000000000/" - "providers/Microsoft.Sql?api-version=2016-02-01") + register_get_url = ( + "https://management.azure.com/" + "subscriptions/12345678-9abc-def0-0000-000000000000/" + "providers/Microsoft.Sql?api-version=2016-02-01" + ) - register_get_result = { - "registrationState":"Registered" - } + register_get_result = {"registrationState": "Registered"} - httpretty.register_uri(httpretty.POST, - register_post_url, - body=json.dumps(register_post_result), - content_type="application/json") + httpretty.register_uri( + httpretty.POST, register_post_url, body=json.dumps(register_post_result), content_type="application/json" + ) - httpretty.register_uri(httpretty.GET, - register_get_url, - body=json.dumps(register_get_result), - content_type="application/json") + httpretty.register_uri( + httpretty.GET, register_get_url, body=json.dumps(register_get_result), content_type="application/json" + ) request = HttpRequest("PUT", provider_url) policies = [ @@ -112,7 +115,7 @@ def test_register_rp_policy(): with Pipeline(RequestsTransport(), policies=policies) as pipeline: response = pipeline.run(request) - assert json.loads(response.http_response.text())['success'] + assert json.loads(response.http_response.text())["success"] @httpretty.activate @@ -124,34 +127,39 @@ def test_register_failed_policy(): - This POST failed """ - provider_url = ("https://management.azure.com/" - "subscriptions/12345678-9abc-def0-0000-000000000000/" - "resourceGroups/clitest.rg000001/" - "providers/Microsoft.Sql/servers/ygserver123?api-version=2014-04-01") + provider_url = ( + "https://management.azure.com/" + "subscriptions/12345678-9abc-def0-0000-000000000000/" + "resourceGroups/clitest.rg000001/" + "providers/Microsoft.Sql/servers/ygserver123?api-version=2014-04-01" + ) - provider_error = ('{"error":{"code":"MissingSubscriptionRegistration", ' - '"message":"The subscription registration is in \'Unregistered\' state. ' - 'The subscription must be registered to use namespace \'Microsoft.Sql\'. ' - 'See https://aka.ms/rps-not-found for how to register subscriptions."}}') + provider_error = ( + '{"error":{"code":"MissingSubscriptionRegistration", ' + '"message":"The subscription registration is in \'Unregistered\' state. ' + "The subscription must be registered to use namespace 'Microsoft.Sql'. " + 'See https://aka.ms/rps-not-found for how to register subscriptions."}}' + ) provider_success = '{"success": true}' - httpretty.register_uri(httpretty.PUT, - provider_url, - responses=[ - httpretty.Response(body=provider_error, status=409), - httpretty.Response(body=provider_success), - ], - content_type="application/json") + httpretty.register_uri( + httpretty.PUT, + provider_url, + responses=[ + httpretty.Response(body=provider_error, status=409), + httpretty.Response(body=provider_success), + ], + content_type="application/json", + ) - register_post_url = ("https://management.azure.com/" - "subscriptions/12345678-9abc-def0-0000-000000000000/" - "providers/Microsoft.Sql/register?api-version=2016-02-01") + register_post_url = ( + "https://management.azure.com/" + "subscriptions/12345678-9abc-def0-0000-000000000000/" + "providers/Microsoft.Sql/register?api-version=2016-02-01" + ) - httpretty.register_uri(httpretty.POST, - register_post_url, - status=409, - content_type="application/json") + httpretty.register_uri(httpretty.POST, register_post_url, status=409, content_type="application/json") request = HttpRequest("PUT", provider_url) policies = [ @@ -162,6 +170,7 @@ def test_register_failed_policy(): assert response.http_response.status_code == 409 + def test_default_http_logging_policy(): config = Configuration() pipeline_client = ARMPipelineClient(base_url="test", config=config) @@ -169,15 +178,18 @@ def test_default_http_logging_policy(): assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST + def test_pass_in_http_logging_policy(): config = Configuration() http_logging_policy = ARMHttpLoggingPolicy() - http_logging_policy.allowed_header_names.update( - {"x-ms-added-header"} - ) + http_logging_policy.allowed_header_names.update({"x-ms-added-header"}) config.http_logging_policy = http_logging_policy pipeline_client = ARMPipelineClient(base_url="test", config=config) http_logging_policy = pipeline_client._pipeline._impl_policies[-1]._policy - assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST.union({"x-ms-added-header"}) - assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) \ No newline at end of file + assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_ALLOWLIST.union( + {"x-ms-added-header"} + ) + assert http_logging_policy.allowed_header_names == ARMHttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union( + {"x-ms-added-header"} + ) diff --git a/sdk/core/azure-mgmt-core/tests/test_tools.py b/sdk/core/azure-mgmt-core/tests/test_tools.py index 831bf9b4835c..ba762da0066d 100644 --- a/sdk/core/azure-mgmt-core/tests/test_tools.py +++ b/sdk/core/azure-mgmt-core/tests/test_tools.py @@ -1,4 +1,4 @@ -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # # The MIT License (MIT) @@ -21,351 +21,333 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. # -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import unittest from azure.mgmt.core.tools import parse_resource_id, is_valid_resource_id, resource_id, is_valid_resource_name -class TestTools(unittest.TestCase): +class TestTools(unittest.TestCase): def test_resource_parse(self): """ Tests resource id parsing, reforming, and validation. """ tests = [ { - 'resource_id': '/subscriptions/fakesub/resourcegroups/testgroup/providers' - '/Microsoft.Storage/storageAccounts/foo/providers' - '/Microsoft.Authorization/locks/bar', - 'expected': { - 'name': 'foo', - 'type': 'storageAccounts', - 'namespace': 'Microsoft.Storage', - 'child_name_1': 'bar', - 'child_namespace_1': 'Microsoft.Authorization', - 'child_type_1': 'locks', - 'child_parent_1': 'storageAccounts/foo/providers/Microsoft.Authorization/', - 'resource_group': 'testgroup', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/resourcegroups/testgroup/providers" + "/Microsoft.Storage/storageAccounts/foo/providers" + "/Microsoft.Authorization/locks/bar", + "expected": { + "name": "foo", + "type": "storageAccounts", + "namespace": "Microsoft.Storage", + "child_name_1": "bar", + "child_namespace_1": "Microsoft.Authorization", + "child_type_1": "locks", + "child_parent_1": "storageAccounts/foo/providers/Microsoft.Authorization/", + "resource_group": "testgroup", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/fakesub/resourcegroups/testgroup/providers' - '/Microsoft.Storage/storageAccounts/foo' - '/locks/bar', - 'expected': { - 'name': 'foo', - 'type': 'storageAccounts', - 'namespace': 'Microsoft.Storage', - 'child_name_1': 'bar', - 'child_type_1': 'locks', - 'child_parent_1': 'storageAccounts/foo/', - 'resource_group': 'testgroup', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/resourcegroups/testgroup/providers" + "/Microsoft.Storage/storageAccounts/foo" + "/locks/bar", + "expected": { + "name": "foo", + "type": "storageAccounts", + "namespace": "Microsoft.Storage", + "child_name_1": "bar", + "child_type_1": "locks", + "child_parent_1": "storageAccounts/foo/", + "resource_group": "testgroup", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/fakesub/resourcegroups/testgroup/providers' - '/Microsoft.Storage/storageAccounts/foo/providers' - '/Microsoft.Authorization/locks/bar/providers/Microsoft.Network/' - 'nets/gc', - 'expected': { - 'name': 'foo', - 'type': 'storageAccounts', - 'namespace': 'Microsoft.Storage', - 'child_name_1': 'bar', - 'child_namespace_1': 'Microsoft.Authorization', - 'child_type_1': 'locks', - 'child_parent_1': 'storageAccounts/foo/providers/Microsoft.Authorization/', - 'child_name_2': 'gc', - 'child_namespace_2': 'Microsoft.Network', - 'child_type_2': 'nets', - 'child_parent_2': 'storageAccounts/foo/providers/Microsoft.Authorization/' - 'locks/bar/providers/Microsoft.Network/', - 'resource_group': 'testgroup', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/resourcegroups/testgroup/providers" + "/Microsoft.Storage/storageAccounts/foo/providers" + "/Microsoft.Authorization/locks/bar/providers/Microsoft.Network/" + "nets/gc", + "expected": { + "name": "foo", + "type": "storageAccounts", + "namespace": "Microsoft.Storage", + "child_name_1": "bar", + "child_namespace_1": "Microsoft.Authorization", + "child_type_1": "locks", + "child_parent_1": "storageAccounts/foo/providers/Microsoft.Authorization/", + "child_name_2": "gc", + "child_namespace_2": "Microsoft.Network", + "child_type_2": "nets", + "child_parent_2": "storageAccounts/foo/providers/Microsoft.Authorization/" + "locks/bar/providers/Microsoft.Network/", + "resource_group": "testgroup", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/fakesub/resourcegroups/testgroup/providers' - '/Microsoft.Storage/storageAccounts/foo' - '/locks/bar/nets/gc', - 'expected': { - 'name': 'foo', - 'type': 'storageAccounts', - 'namespace': 'Microsoft.Storage', - 'child_name_1': 'bar', - 'child_type_1': 'locks', - 'child_parent_1': 'storageAccounts/foo/', - 'child_name_2': 'gc', - 'child_type_2': 'nets', - 'child_parent_2': 'storageAccounts/foo/locks/bar/', - 'resource_group': 'testgroup', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/resourcegroups/testgroup/providers" + "/Microsoft.Storage/storageAccounts/foo" + "/locks/bar/nets/gc", + "expected": { + "name": "foo", + "type": "storageAccounts", + "namespace": "Microsoft.Storage", + "child_name_1": "bar", + "child_type_1": "locks", + "child_parent_1": "storageAccounts/foo/", + "child_name_2": "gc", + "child_type_2": "nets", + "child_parent_2": "storageAccounts/foo/locks/bar/", + "resource_group": "testgroup", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/providers/' - 'Microsoft.Provider1/resourceType1/name1', - 'expected': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1', - 'type': 'resourceType1', - 'name': 'name1', - 'resource_parent': '', - 'resource_namespace': 'Microsoft.Provider1', - 'resource_type': 'resourceType1', - 'resource_name': 'name1' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/providers/" + "Microsoft.Provider1/resourceType1/name1", + "expected": { + "subscription": "mySub", + "resource_group": "myRg", + "namespace": "Microsoft.Provider1", + "type": "resourceType1", + "name": "name1", + "resource_parent": "", + "resource_namespace": "Microsoft.Provider1", + "resource_type": "resourceType1", + "resource_name": "name1", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/providers/' - 'Microsoft.Provider1/resourceType1/name1/resourceType2/name2', - 'expected': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1', - 'type': 'resourceType1', - 'name': 'name1', - 'child_namespace_1': None, - 'child_type_1': 'resourceType2', - 'child_name_1': 'name2', - 'child_parent_1': 'resourceType1/name1/', - 'resource_parent': 'resourceType1/name1/', - 'resource_namespace': 'Microsoft.Provider1', - 'resource_type': 'resourceType2', - 'resource_name': 'name2' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/providers/" + "Microsoft.Provider1/resourceType1/name1/resourceType2/name2", + "expected": { + "subscription": "mySub", + "resource_group": "myRg", + "namespace": "Microsoft.Provider1", + "type": "resourceType1", + "name": "name1", + "child_namespace_1": None, + "child_type_1": "resourceType2", + "child_name_1": "name2", + "child_parent_1": "resourceType1/name1/", + "resource_parent": "resourceType1/name1/", + "resource_namespace": "Microsoft.Provider1", + "resource_type": "resourceType2", + "resource_name": "name2", + }, }, { - 'resource_id': '/subscriptions/00000/resourceGroups/myRg/providers/' - 'Microsoft.RecoveryServices/vaults/vault_name/backupFabrics/' - 'fabric_name/protectionContainers/container_name/' - 'protectedItems/item_name/recoveryPoint/recovery_point_guid', - 'expected': { - 'subscription': '00000', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.RecoveryServices', - 'type': 'vaults', - 'name': 'vault_name', - 'child_type_1': 'backupFabrics', - 'child_name_1': 'fabric_name', - 'child_parent_1': 'vaults/vault_name/', - 'child_type_2': 'protectionContainers', - 'child_name_2': 'container_name', - 'child_parent_2': 'vaults/vault_name/backupFabrics/fabric_name/', - 'child_type_3': 'protectedItems', - 'child_name_3': 'item_name', - 'child_parent_3': 'vaults/vault_name/backupFabrics/fabric_name/' - 'protectionContainers/container_name/', - 'child_type_4': 'recoveryPoint', - 'child_name_4': 'recovery_point_guid', - 'child_parent_4': 'vaults/vault_name/backupFabrics/fabric_name/' - 'protectionContainers/container_name/protectedItems/' - 'item_name/', - 'resource_parent': 'vaults/vault_name/backupFabrics/fabric_name/' - 'protectionContainers/container_name/protectedItems/' - 'item_name/', - 'resource_namespace': 'Microsoft.RecoveryServices', - 'resource_type': 'recoveryPoint', - 'resource_name': 'recovery_point_guid' - } + "resource_id": "/subscriptions/00000/resourceGroups/myRg/providers/" + "Microsoft.RecoveryServices/vaults/vault_name/backupFabrics/" + "fabric_name/protectionContainers/container_name/" + "protectedItems/item_name/recoveryPoint/recovery_point_guid", + "expected": { + "subscription": "00000", + "resource_group": "myRg", + "namespace": "Microsoft.RecoveryServices", + "type": "vaults", + "name": "vault_name", + "child_type_1": "backupFabrics", + "child_name_1": "fabric_name", + "child_parent_1": "vaults/vault_name/", + "child_type_2": "protectionContainers", + "child_name_2": "container_name", + "child_parent_2": "vaults/vault_name/backupFabrics/fabric_name/", + "child_type_3": "protectedItems", + "child_name_3": "item_name", + "child_parent_3": "vaults/vault_name/backupFabrics/fabric_name/" + "protectionContainers/container_name/", + "child_type_4": "recoveryPoint", + "child_name_4": "recovery_point_guid", + "child_parent_4": "vaults/vault_name/backupFabrics/fabric_name/" + "protectionContainers/container_name/protectedItems/" + "item_name/", + "resource_parent": "vaults/vault_name/backupFabrics/fabric_name/" + "protectionContainers/container_name/protectedItems/" + "item_name/", + "resource_namespace": "Microsoft.RecoveryServices", + "resource_type": "recoveryPoint", + "resource_name": "recovery_point_guid", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/providers/' - 'Microsoft.Provider1/resourceType1/name1/resourceType2/name2/' - 'providers/Microsoft.Provider3/resourceType3/name3', - 'expected': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1', - 'type': 'resourceType1', - 'name': 'name1', - 'child_namespace_1': None, - 'child_type_1': 'resourceType2', - 'child_name_1': 'name2', - 'child_parent_1' : 'resourceType1/name1/', - 'child_namespace_2': 'Microsoft.Provider3', - 'child_type_2': 'resourceType3', - 'child_name_2': 'name3', - 'child_parent_2': 'resourceType1/name1/resourceType2/name2/' - 'providers/Microsoft.Provider3/', - 'resource_parent': 'resourceType1/name1/resourceType2/name2/' - 'providers/Microsoft.Provider3/', - 'resource_namespace': 'Microsoft.Provider1', - 'resource_type': 'resourceType3', - 'resource_name': 'name3' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/providers/" + "Microsoft.Provider1/resourceType1/name1/resourceType2/name2/" + "providers/Microsoft.Provider3/resourceType3/name3", + "expected": { + "subscription": "mySub", + "resource_group": "myRg", + "namespace": "Microsoft.Provider1", + "type": "resourceType1", + "name": "name1", + "child_namespace_1": None, + "child_type_1": "resourceType2", + "child_name_1": "name2", + "child_parent_1": "resourceType1/name1/", + "child_namespace_2": "Microsoft.Provider3", + "child_type_2": "resourceType3", + "child_name_2": "name3", + "child_parent_2": "resourceType1/name1/resourceType2/name2/" "providers/Microsoft.Provider3/", + "resource_parent": "resourceType1/name1/resourceType2/name2/" "providers/Microsoft.Provider3/", + "resource_namespace": "Microsoft.Provider1", + "resource_type": "resourceType3", + "resource_name": "name3", + }, }, { - 'resource_id': '/subscriptions/fakesub/providers/Microsoft.Authorization' - '/locks/foo', - 'expected': { - 'name': 'foo', - 'type': 'locks', - 'namespace': 'Microsoft.Authorization', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/providers/Microsoft.Authorization" "/locks/foo", + "expected": { + "name": "foo", + "type": "locks", + "namespace": "Microsoft.Authorization", + "subscription": "fakesub", + }, }, { - 'resource_id': '/Subscriptions/fakesub/providers/Microsoft.Authorization' - '/locks/foo', - 'expected': { - 'name': 'foo', - 'type': 'locks', - 'namespace': 'Microsoft.Authorization', - 'subscription': 'fakesub', - } + "resource_id": "/Subscriptions/fakesub/providers/Microsoft.Authorization" "/locks/foo", + "expected": { + "name": "foo", + "type": "locks", + "namespace": "Microsoft.Authorization", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg', - 'expected': { - 'subscription': 'mySub', - 'resource_group': 'myRg' - } - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg", + "expected": {"subscription": "mySub", "resource_group": "myRg"}, + }, ] for test in tests: - self.assertTrue(is_valid_resource_id(test['resource_id'])) - kwargs = parse_resource_id(test['resource_id']) - for key in test['expected']: + self.assertTrue(is_valid_resource_id(test["resource_id"])) + kwargs = parse_resource_id(test["resource_id"]) + for key in test["expected"]: try: - self.assertEqual(kwargs[key], test['expected'][key]) + self.assertEqual(kwargs[key], test["expected"][key]) except KeyError: - self.assertTrue(key not in kwargs and test['expected'][key] is None) + self.assertTrue(key not in kwargs and test["expected"][key] is None) invalid_ids = [ - '/subscriptions/fakesub/resourceGroups/myRg/type1/name1', - '/subscriptions/fakesub/resourceGroups/myRg/providers/Microsoft.Provider/foo', - '/subscriptions/fakesub/resourceGroups/myRg/providers/namespace/type/name/type1', - '/subscriptions/fakesub/resourceGroups/', - '/subscriptions//resourceGroups/' + "/subscriptions/fakesub/resourceGroups/myRg/type1/name1", + "/subscriptions/fakesub/resourceGroups/myRg/providers/Microsoft.Provider/foo", + "/subscriptions/fakesub/resourceGroups/myRg/providers/namespace/type/name/type1", + "/subscriptions/fakesub/resourceGroups/", + "/subscriptions//resourceGroups/", ] for invalid_id in invalid_ids: self.assertFalse(is_valid_resource_id(invalid_id)) tests = [ { - 'resource_id': '/subscriptions/fakesub/resourcegroups/testgroup/providers' - '/Microsoft.Storage/storageAccounts/foo/providers' - '/Microsoft.Authorization/locks/bar', - 'id_args': { - 'name': 'foo', - 'type': 'storageAccounts', - 'namespace': 'Microsoft.Storage', - 'child_name_1': 'bar', - 'child_namespace_1': 'Microsoft.Authorization', - 'child_type_1': 'locks', - 'resource_group': 'testgroup', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/resourcegroups/testgroup/providers" + "/Microsoft.Storage/storageAccounts/foo/providers" + "/Microsoft.Authorization/locks/bar", + "id_args": { + "name": "foo", + "type": "storageAccounts", + "namespace": "Microsoft.Storage", + "child_name_1": "bar", + "child_namespace_1": "Microsoft.Authorization", + "child_type_1": "locks", + "resource_group": "testgroup", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/fakesub/resourcegroups/testgroup/providers' - '/Microsoft.Storage/storageAccounts/foo' - '/locks/bar', - 'id_args': { - 'name': 'foo', - 'type': 'storageAccounts', - 'namespace': 'Microsoft.Storage', - 'child_name_1': 'bar', - 'child_type_1': 'locks', - 'resource_group': 'testgroup', - 'subscription': 'fakesub', - } + "resource_id": "/subscriptions/fakesub/resourcegroups/testgroup/providers" + "/Microsoft.Storage/storageAccounts/foo" + "/locks/bar", + "id_args": { + "name": "foo", + "type": "storageAccounts", + "namespace": "Microsoft.Storage", + "child_name_1": "bar", + "child_type_1": "locks", + "resource_group": "testgroup", + "subscription": "fakesub", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/providers/' - 'Microsoft.Provider1/resourceType1/name1/resourceType2/name2/' - 'providers/Microsoft.Provider3/resourceType3/name3', - 'id_args': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1', - 'type': 'resourceType1', - 'name': 'name1', - 'child_type_1': 'resourceType2', - 'child_name_1': 'name2', - 'child_namespace_2': 'Microsoft.Provider3', - 'child_type_2': 'resourceType3', - 'child_name_2': 'name3' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/providers/" + "Microsoft.Provider1/resourceType1/name1/resourceType2/name2/" + "providers/Microsoft.Provider3/resourceType3/name3", + "id_args": { + "subscription": "mySub", + "resource_group": "myRg", + "namespace": "Microsoft.Provider1", + "type": "resourceType1", + "name": "name1", + "child_type_1": "resourceType2", + "child_name_1": "name2", + "child_namespace_2": "Microsoft.Provider3", + "child_type_2": "resourceType3", + "child_name_2": "name3", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/' - 'providers/Microsoft.Provider1', - 'id_args': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/" "providers/Microsoft.Provider1", + "id_args": {"subscription": "mySub", "resource_group": "myRg", "namespace": "Microsoft.Provider1"}, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg', - 'id_args': { - 'subscription': 'mySub', - 'resource_group': 'myRg' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg", + "id_args": {"subscription": "mySub", "resource_group": "myRg"}, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/' - 'providers/Microsoft.Provider1/resourceType1/name1/resourceType2/' - 'name2/providers/Microsoft.Provider3', - 'id_args': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1', - 'type': 'resourceType1', - 'name': 'name1', - 'child_type_1': 'resourceType2', - 'child_name_1': 'name2', - 'child_namespace_2': 'Microsoft.Provider3' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/" + "providers/Microsoft.Provider1/resourceType1/name1/resourceType2/" + "name2/providers/Microsoft.Provider3", + "id_args": { + "subscription": "mySub", + "resource_group": "myRg", + "namespace": "Microsoft.Provider1", + "type": "resourceType1", + "name": "name1", + "child_type_1": "resourceType2", + "child_name_1": "name2", + "child_namespace_2": "Microsoft.Provider3", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg/' - 'providers/Microsoft.Provider1/resourceType1/name1', - 'id_args': { - 'subscription': 'mySub', - 'resource_group': 'myRg', - 'namespace': 'Microsoft.Provider1', - 'type': 'resourceType1', - 'name': 'name1', - 'child_type_1': None, - 'child_name_1': 'name2', - 'child_namespace_2': 'Microsoft.Provider3' - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg/" + "providers/Microsoft.Provider1/resourceType1/name1", + "id_args": { + "subscription": "mySub", + "resource_group": "myRg", + "namespace": "Microsoft.Provider1", + "type": "resourceType1", + "name": "name1", + "child_type_1": None, + "child_name_1": "name2", + "child_namespace_2": "Microsoft.Provider3", + }, }, { - 'resource_id': '/subscriptions/mySub/resourceGroups/myRg', - 'id_args': { - 'subscription': 'mySub', - 'resource_group': 'myRg' - } - } + "resource_id": "/subscriptions/mySub/resourceGroups/myRg", + "id_args": {"subscription": "mySub", "resource_group": "myRg"}, + }, ] for test in tests: - rsrc_id = resource_id(**test['id_args']) - self.assertEqual(rsrc_id.lower(), test['resource_id'].lower()) + rsrc_id = resource_id(**test["id_args"]) + self.assertEqual(rsrc_id.lower(), test["resource_id"].lower()) def test_is_resource_name(self): invalid_names = [ - '', - 'knights/ni', - 'spam&eggs', - 'i<3you', - 'a' * 261, + "", + "knights/ni", + "spam&eggs", + "i<3you", + "a" * 261, ] for test in invalid_names: assert not is_valid_resource_name(test) valid_names = [ - 'abc-123', - ' ', # no one said it had to be a good resource name. - 'a' * 260, + "abc-123", + " ", # no one said it had to be a good resource name. + "a" * 260, ] for test in valid_names: