diff --git a/sdk/core/azure-core/HISTORY.md b/sdk/core/azure-core/HISTORY.md index 2a96cb826c44..d855d0b93c9d 100644 --- a/sdk/core/azure-core/HISTORY.md +++ b/sdk/core/azure-core/HISTORY.md @@ -9,6 +9,14 @@ - Tracing: network span context is available with the TRACING_CONTEXT in pipeline response #7252 - Tracing: Span contract now has `kind`, `traceparent` and is a context manager #7252 +- SansIOHTTPPolicy methods can now be coroutines #7497 +- Add multipart/mixed support #7083: + + - HttpRequest now has a "set_multipart_mixed" method to set the parts of this request + - HttpRequest now has a "prepare_multipart_body" method to build final body. + - HttpResponse now has a "parts" method to return an iterator of parts + - AsyncHttpResponse now has a "parts" methods to return an async iterator of parts + - Note that multipart/MIXED is a Python 3.x only feature ### Bug fixes diff --git a/sdk/core/azure-core/README.md b/sdk/core/azure-core/README.md index c0f2ff5ca18b..100580d4ec6d 100644 --- a/sdk/core/azure-core/README.md +++ b/sdk/core/azure-core/README.md @@ -241,6 +241,11 @@ class HttpRequest(object): def set_bytes_body(self, data): """Set generic bytes as the body of the request.""" + + def set_multipart_mixed(self, *requests, **kwargs): + """Set requests for a multipart/mixed body. + Optionally apply "policies" in kwargs to each request. + """ ``` The HttpResponse object on the other hand will generally have a transport-specific derivative. @@ -285,6 +290,12 @@ class HttpResponse(object): and asynchronous generator. """ + def parts(self): + """An iterator of parts if content-type is multipart/mixed. + For the AsyncHttpResponse object this function will return + and asynchronous iterator. + """ + ``` ### PipelineRequest and PipelineResponse @@ -344,6 +355,8 @@ def on_exception(self, request): """ ``` +SansIOHTTPPolicy methods can be declared as coroutines, but then they can only be used with a AsyncPipeline. + Current provided sans IO policies include: ```python from azure.core.pipeline.policies import ( diff --git a/sdk/core/azure-core/azure/core/pipeline/__init__.py b/sdk/core/azure-core/azure/core/pipeline/__init__.py index b7b06deacbf6..a5458c4c450f 100644 --- a/sdk/core/azure-core/azure/core/pipeline/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/__init__.py @@ -25,19 +25,22 @@ # -------------------------------------------------------------------------- import abc -from typing import (TypeVar, Any, Dict, Optional, Generic) +from typing import TypeVar, Generic try: ABC = abc.ABC -except AttributeError: # Python 2.7, abc exists, but not ABC - ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) # type: ignore +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") try: - from contextlib import AbstractContextManager # type: ignore #pylint: disable=unused-import -except ImportError: # Python <= 3.5 + from contextlib import ( # pylint: disable=unused-import + AbstractContextManager, + ) # type: ignore +except ImportError: # Python <= 3.5 + class AbstractContextManager(object): # type: ignore def __enter__(self): """Return `self` upon entering the runtime context.""" @@ -60,19 +63,20 @@ class PipelineContext(dict): :param transport: The HTTP transport type. :param kwargs: Developer-defined keyword arguments. """ - def __init__(self, transport, **kwargs): #pylint: disable=super-init-not-called + + def __init__(self, transport, **kwargs): # pylint: disable=super-init-not-called self.transport = transport self.options = kwargs - self._protected = ['transport', 'options'] + self._protected = ["transport", "options"] def __setitem__(self, key, item): if key in self._protected: - raise ValueError('Context value {} cannot be overwritten.'.format(key)) + raise ValueError("Context value {} cannot be overwritten.".format(key)) return super(PipelineContext, self).__setitem__(key, item) def __delitem__(self, key): if key in self._protected: - raise ValueError('Context value {} cannot be deleted.'.format(key)) + raise ValueError("Context value {} cannot be deleted.".format(key)) return super(PipelineContext, self).__delitem__(key) def clear(self): @@ -93,7 +97,7 @@ def pop(self, *args): """Removes specified key and returns the value. """ if args and args[0] in self._protected: - raise ValueError('Context value {} cannot be popped.'.format(args[0])) + raise ValueError("Context value {} cannot be popped.".format(args[0])) return super(PipelineContext, self).pop(*args) @@ -108,6 +112,7 @@ class PipelineRequest(Generic[HTTPRequestType]): :param context: Contains the context - data persisted between pipeline requests. :type context: ~azure.core.pipeline.PipelineContext """ + def __init__(self, http_request, context): # type: (HTTPRequestType, PipelineContext) -> None self.http_request = http_request @@ -131,6 +136,7 @@ class PipelineResponse(Generic[HTTPRequestType, HTTPResponseType]): :param context: Contains the context - data persisted between pipeline requests. :type context: ~azure.core.pipeline.PipelineContext """ + def __init__(self, http_request, http_response, context): # type: (HTTPRequestType, HTTPResponseType, PipelineContext) -> None self.http_request = http_request @@ -138,17 +144,13 @@ def __init__(self, http_request, http_response, context): self.context = context -from .base import Pipeline #pylint: disable=wrong-import-position +from .base import Pipeline # pylint: disable=wrong-import-position -__all__ = [ - 'Pipeline', - 'PipelineRequest', - 'PipelineResponse', - 'PipelineContext' -] +__all__ = ["Pipeline", "PipelineRequest", "PipelineResponse", "PipelineContext"] try: - from .base_async import AsyncPipeline #pylint: disable=unused-import - __all__.append('AsyncPipeline') + from .base_async import AsyncPipeline # pylint: disable=unused-import + + __all__.append("AsyncPipeline") except (SyntaxError, ImportError): pass # Asynchronous pipelines not supported. diff --git a/sdk/core/azure-core/azure/core/pipeline/base.py b/sdk/core/azure-core/azure/core/pipeline/base.py index e0f940249a3b..109827c818e4 100644 --- a/sdk/core/azure-core/azure/core/pipeline/base.py +++ b/sdk/core/azure-core/azure/core/pipeline/base.py @@ -25,10 +25,15 @@ # -------------------------------------------------------------------------- import logging -from typing import (TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, # pylint: disable=unused-import - Tuple, Callable, Iterator) -from azure.core.pipeline import AbstractContextManager, PipelineRequest, PipelineResponse, PipelineContext +from typing import Generic, TypeVar, List, Union, Any +from azure.core.pipeline import ( + AbstractContextManager, + PipelineRequest, + PipelineResponse, + PipelineContext, +) from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy + HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") HttpTransportType = TypeVar("HttpTransportType") @@ -40,8 +45,10 @@ def _await_result(func, *args, **kwargs): """If func returns an awaitable, raise that this runner can't handle it.""" result = func(*args, **kwargs) - if hasattr(result, '__await__'): - raise TypeError("Policy {} returned awaitable object in non-async pipeline.".format(func)) + if hasattr(result, "__await__"): + raise TypeError( + "Policy {} returned awaitable object in non-async pipeline.".format(func) + ) return result @@ -71,7 +78,7 @@ def send(self, request): _await_result(self._policy.on_request, request) try: response = self.next.send(request) - except Exception: #pylint: disable=broad-except + except Exception: # pylint: disable=broad-except if not _await_result(self._policy.on_exception, request): raise else: @@ -86,6 +93,7 @@ class _TransportRunner(HTTPPolicy): :param sender: The Http Transport instance. """ + def __init__(self, sender): # type: (HttpTransportType) -> None super(_TransportRunner, self).__init__() @@ -102,7 +110,7 @@ def send(self, request): return PipelineResponse( request.http_request, self._sender.send(request.http_request, **request.context.options), - context=request.context + context=request.context, ) @@ -123,29 +131,59 @@ class Pipeline(AbstractContextManager, Generic[HTTPRequestType, HTTPResponseType :dedent: 4 :caption: Builds the pipeline for synchronous transport. """ + def __init__(self, transport, policies=None): # type: (HttpTransportType, PoliciesType) -> None self._impl_policies = [] # type: List[HTTPPolicy] self._transport = transport # type: ignore - for policy in (policies or []): + for policy in policies or []: if isinstance(policy, SansIOHTTPPolicy): self._impl_policies.append(_SansIOHTTPPolicyRunner(policy)) elif policy: self._impl_policies.append(policy) - for index in range(len(self._impl_policies)-1): - self._impl_policies[index].next = self._impl_policies[index+1] + for index in range(len(self._impl_policies) - 1): + self._impl_policies[index].next = self._impl_policies[index + 1] if self._impl_policies: self._impl_policies[-1].next = _TransportRunner(self._transport) def __enter__(self): # type: () -> Pipeline - self._transport.__enter__() # type: ignore + self._transport.__enter__() # type: ignore return self def __exit__(self, *exc_details): # pylint: disable=arguments-differ self._transport.__exit__(*exc_details) + @staticmethod + def _prepare_multipart_mixed_request(request): + # type: (HTTPRequestType) -> None + """Will execute the multipart policies. + + Does nothing if "set_multipart_mixed" was never called. + """ + multipart_mixed_info = request.multipart_mixed_info # type: ignore + if not multipart_mixed_info: + return + + requests = multipart_mixed_info[0] # type: List[HTTPRequestType] + policies = multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] + + # Apply on_requests concurrently to all requests + import concurrent.futures + + def prepare_requests(req): + context = PipelineContext(None) + pipeline_request = PipelineRequest(req, context) + for policy in policies: + _await_result(policy.on_request, pipeline_request) + + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned + _ for _ in executor.map(prepare_requests, requests) + ] + def run(self, request, **kwargs): # type: (HTTPRequestType, Any) -> PipelineResponse """Runs the HTTP Request through the chained policies. @@ -155,7 +193,15 @@ def run(self, request, **kwargs): :return: The PipelineResponse object :rtype: ~azure.core.pipeline.PipelineResponse """ + self._prepare_multipart_mixed_request(request) + request.prepare_multipart_body() # type: ignore context = PipelineContext(self._transport, **kwargs) - pipeline_request = PipelineRequest(request, context) # type: PipelineRequest[HTTPRequestType] - first_node = self._impl_policies[0] if self._impl_policies else _TransportRunner(self._transport) + pipeline_request = PipelineRequest( + request, context + ) # type: PipelineRequest[HTTPRequestType] + first_node = ( + self._impl_policies[0] + if self._impl_policies + else _TransportRunner(self._transport) + ) return first_node.send(pipeline_request) # type: ignore diff --git a/sdk/core/azure-core/azure/core/pipeline/base_async.py b/sdk/core/azure-core/azure/core/pipeline/base_async.py index f8eea3706ef9..87ffa563cabd 100644 --- a/sdk/core/azure-core/azure/core/pipeline/base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/base_async.py @@ -32,12 +32,17 @@ AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") -ImplPoliciesType = List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]] #pylint: disable=unsubscriptable-object +ImplPoliciesType = List[ + AsyncHTTPPolicy[ # pylint: disable=unsubscriptable-object + HTTPRequestType, AsyncHTTPResponseType + ] +] AsyncPoliciesType = List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]] try: from contextlib import AbstractAsyncContextManager # type: ignore -except ImportError: # Python <= 3.7 +except ImportError: # Python <= 3.7 + class AbstractAsyncContextManager(object): # type: ignore async def __aenter__(self): """Return `self` upon entering the runtime context.""" @@ -52,13 +57,15 @@ async def __aexit__(self, exc_type, exc_value, traceback): async def _await_result(func, *args, **kwargs): """If func returns an awaitable, await it.""" result = func(*args, **kwargs) - if hasattr(result, '__await__'): + if hasattr(result, "__await__"): # type ignore on await: https://github.com/python/mypy/issues/7587 return await result # type: ignore return result -class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): #pylint: disable=unsubscriptable-object +class _SansIOAsyncHTTPPolicyRunner( + AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType] +): # pylint: disable=unsubscriptable-object """Async implementation of the SansIO policy. Modifies the request and sends to the next policy in the chain. @@ -82,7 +89,7 @@ async def send(self, request: PipelineRequest) -> PipelineResponse: await _await_result(self._policy.on_request, request) try: response = await self.next.send(request) # type: ignore - except Exception: #pylint: disable=broad-except + except Exception: # pylint: disable=broad-except if not await _await_result(self._policy.on_exception, request): raise else: @@ -90,13 +97,16 @@ async def send(self, request: PipelineRequest) -> PipelineResponse: return response -class _AsyncTransportRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): #pylint: disable=unsubscriptable-object +class _AsyncTransportRunner( + AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType] +): # pylint: disable=unsubscriptable-object """Async Transport runner. Uses specified HTTP transport type to send request and returns response. :param sender: The async Http Transport instance. """ + def __init__(self, sender) -> None: super(_AsyncTransportRunner, self).__init__() self._sender = sender @@ -112,11 +122,13 @@ async def send(self, request): return PipelineResponse( request.http_request, await self._sender.send(request.http_request, **request.context.options), - request.context + request.context, ) -class AsyncPipeline(AbstractAsyncContextManager, Generic[HTTPRequestType, AsyncHTTPResponseType]): +class AsyncPipeline( + AbstractAsyncContextManager, Generic[HTTPRequestType, AsyncHTTPResponseType] +): """Async pipeline implementation. This is implemented as a context manager, that will activate the context @@ -138,13 +150,13 @@ def __init__(self, transport, policies: AsyncPoliciesType = None) -> None: self._impl_policies = [] # type: ImplPoliciesType self._transport = transport - for policy in (policies or []): + for policy in policies or []: if isinstance(policy, SansIOHTTPPolicy): self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy)) elif policy: self._impl_policies.append(policy) - for index in range(len(self._impl_policies)-1): - self._impl_policies[index].next = self._impl_policies[index+1] + for index in range(len(self._impl_policies) - 1): + self._impl_policies[index].next = self._impl_policies[index + 1] if self._impl_policies: self._impl_policies[-1].next = _AsyncTransportRunner(self._transport) @@ -155,14 +167,39 @@ def __exit__(self, exc_type, exc_val, exc_tb): # __exit__ should exist in pair with __enter__ but never executed pass # pragma: no cover - async def __aenter__(self) -> 'AsyncPipeline': + async def __aenter__(self) -> "AsyncPipeline": await self._transport.__aenter__() return self async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ await self._transport.__aexit__(*exc_details) - async def run(self, request: PipelineRequest[HTTPRequestType], **kwargs: Any): + async def _prepare_multipart_mixed_request(self, request): + # type: (HTTPRequestType) -> None + """Will execute the multipart policies. + + Does nothing if "set_multipart_mixed" was never called. + """ + multipart_mixed_info = request.multipart_mixed_info # type: ignore + if not multipart_mixed_info: + return + + requests = multipart_mixed_info[0] # type: List[HTTPRequestType] + policies = multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] + + async def prepare_requests(req): + context = PipelineContext(None) + pipeline_request = PipelineRequest(req, context) + for policy in policies: + await _await_result(policy.on_request, pipeline_request) + + # Not happy to make this code asyncio specific, but that's multipart only for now + # If we need trio and multipart, let's reinvesitgate that later + import asyncio + + await asyncio.gather(*[prepare_requests(req) for req in requests]) + + async def run(self, request: HTTPRequestType, **kwargs: Any): """Runs the HTTP Request through the chained policies. :param request: The HTTP request object. @@ -170,7 +207,13 @@ async def run(self, request: PipelineRequest[HTTPRequestType], **kwargs: Any): :return: The PipelineResponse object. :rtype: ~azure.core.pipeline.PipelineResponse """ + await self._prepare_multipart_mixed_request(request) + request.prepare_multipart_body() # type: ignore context = PipelineContext(self._transport, **kwargs) pipeline_request = PipelineRequest(request, context) - first_node = self._impl_policies[0] if self._impl_policies else _AsyncTransportRunner(self._transport) + first_node = ( + self._impl_policies[0] + if self._impl_policies + else _AsyncTransportRunner(self._transport) + ) return await first_node.send(pipeline_request) # type: ignore diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/base.py b/sdk/core/azure-core/azure/core/pipeline/transport/base.py index 4fb45d3b9b0d..ac4951cf7d0b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/base.py @@ -25,28 +25,56 @@ # -------------------------------------------------------------------------- from __future__ import absolute_import import abc +from email.message import Message + +try: + from email import message_from_bytes as message_parser +except ImportError: # 2.7 + from email import message_from_string as message_parser # type: ignore +from io import BytesIO import json import logging import os import time + try: binary_type = str - from urlparse import urlparse # type: ignore + from urlparse import urlparse # type: ignore except ImportError: - binary_type = bytes # type: ignore + binary_type = bytes # type: ignore from urllib.parse import urlparse import xml.etree.ElementTree as ET -from typing import (TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, # pylint: disable=unused-import - Optional, Tuple, Callable, Iterator) - -# This file is NOT using any "requests" HTTP implementation -# However, the CaseInsensitiveDict is handy. -# If one day we reach the point where "requests" can be skip totally, -# might provide our own implementation -from requests.structures import CaseInsensitiveDict -from azure.core.pipeline import ABC, AbstractContextManager, PipelineRequest, PipelineResponse - +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, + cast, + IO, + List, + Union, + Any, + Mapping, + Dict, + Optional, + Tuple, + Iterator, +) + +from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse + +from azure.core.pipeline import ( + ABC, + AbstractContextManager, + PipelineRequest, + PipelineResponse, + PipelineContext, +) +from ..base import _await_result + + +if TYPE_CHECKING: + from ..policies import SansIOHTTPPolicy HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") @@ -55,6 +83,32 @@ _LOGGER = logging.getLogger(__name__) +def _case_insensitive_dict(*args, **kwargs): + """Return a case-insensitive dict from a structure that a dict would have accepted. + + Rational is I don't want to re-implement this, but I don't want + to assume "requests" or "aiohttp" are installed either. + So I use the one from "requests" or the one from "aiohttp" ("multidict") + If one day this library is used in an HTTP context without "requests" nor "aiohttp" installed, + we can add "multidict" as a dependency or re-implement our own. + """ + try: + from requests.structures import CaseInsensitiveDict + + return CaseInsensitiveDict(*args, **kwargs) + except ImportError: + pass + try: + # multidict is installed by aiohttp + from multidict import CIMultiDict + + return CIMultiDict(*args, **kwargs) + except ImportError: + raise ValueError( + "Neither 'requests' or 'multidict' are installed and no case-insensitive dict impl have been found" + ) + + def _format_url_section(template, **kwargs): components = template.split("/") while components: @@ -62,7 +116,9 @@ def _format_url_section(template, **kwargs): return template.format(**kwargs) except KeyError as key: formatted_components = template.split("/") - components = [c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c] + components = [ + c for c in formatted_components if "{{{}}}".format(key.args[0]) not in c + ] template = "/".join(components) # No URL sections left - returning None @@ -77,11 +133,42 @@ def _urljoin(base_url, stub_url): :rtype: str """ parsed = urlparse(base_url) - parsed = parsed._replace(path=parsed.path + '/' + stub_url) + parsed = parsed._replace(path=parsed.path + "/" + stub_url) return parsed.geturl() -class HttpTransport(AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore +class _HTTPSerializer(HTTPConnection, object): + """Hacking the stdlib HTTPConnection to serialize HTTP request as strings. + """ + + def __init__(self, *args, **kwargs): + self.buffer = b"" + kwargs.setdefault("host", "fakehost") + super(_HTTPSerializer, self).__init__(*args, **kwargs) + + def putheader(self, header, *values): + if header in ["Host", "Accept-Encoding"]: + return + super(_HTTPSerializer, self).putheader(header, *values) + + def send(self, data): + self.buffer += data + + +def _serialize_request(http_request): + serializer = _HTTPSerializer() + serializer.request( + method=http_request.method, + url=http_request.url, + body=http_request.body, + headers=http_request.headers, + ) + return serializer.buffer + + +class HttpTransport( + AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType] +): # type: ignore """An http sender ABC. """ @@ -99,7 +186,7 @@ def open(self): def close(self): """Close the session if it is not externally owned.""" - def sleep(self, duration): #pylint: disable=no-self-use + def sleep(self, duration): # pylint: disable=no-self-use time.sleep(duration) @@ -115,23 +202,25 @@ class HttpRequest(object): :param data: Body to be sent. :type data: bytes or str. """ + def __init__(self, method, url, headers=None, files=None, data=None): # type: (str, str, Mapping[str, str], Any, Any) -> None self.method = method self.url = url - self.headers = CaseInsensitiveDict(headers) + self.headers = _case_insensitive_dict(headers) self.files = files self.data = data + self.multipart_mixed_info = None # type: Optional[Tuple] def __repr__(self): - return '' % (self.method) + return "" % (self.method) @property def query(self): """The query parameters of the request as a dict.""" query = urlparse(self.url).query if query: - return {p[0]: p[-1] for p in [p.partition('=') for p in query.split('&')]} + return {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} return {} @property @@ -152,11 +241,11 @@ def _format_data(data): :param data: The request field data. :type data: str or file-like object. """ - if hasattr(data, 'read'): + if hasattr(data, "read"): data = cast(IO, data) data_name = None try: - if data.name[0] != '<' and data.name[-1] != '>': + if data.name[0] != "<" and data.name[-1] != ">": data_name = os.path.basename(data.name) except (AttributeError, TypeError): pass @@ -173,14 +262,13 @@ def format_parameters(self, params): """ query = urlparse(self.url).query if query: - self.url = self.url.partition('?')[0] + self.url = self.url.partition("?")[0] existing_params = { - p[0]: p[-1] - for p in [p.partition('=') for p in query.split('&')] + p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] } params.update(existing_params) query_params = ["{}={}".format(k, v) for k, v in params.items()] - query = '?' + '&'.join(query_params) + query = "?" + "&".join(query_params) self.url = self.url + query def set_streamed_data_body(self, data): @@ -188,9 +276,12 @@ def set_streamed_data_body(self, data): :param data: The request field data. """ - if not isinstance(data, binary_type) and \ - not any(hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"]): - raise TypeError("A streamable data source must be an open file-like object or iterable.") + if not isinstance(data, binary_type) and not any( + hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] + ): + raise TypeError( + "A streamable data source must be an open file-like object or iterable." + ) self.data = data self.files = None @@ -204,7 +295,7 @@ def set_xml_body(self, data): else: bytes_data = ET.tostring(data, encoding="utf8") self.data = bytes_data.replace(b"encoding='utf8'", b"encoding='utf-8'") - self.headers['Content-Length'] = str(len(self.data)) + self.headers["Content-Length"] = str(len(self.data)) self.files = None def set_json_body(self, data): @@ -216,7 +307,7 @@ def set_json_body(self, data): self.data = None else: self.data = json.dumps(data) - self.headers['Content-Length'] = str(len(self.data)) + self.headers["Content-Length"] = str(len(self.data)) self.files = None def set_formdata_body(self, data=None): @@ -226,13 +317,15 @@ def set_formdata_body(self, data=None): """ if data is None: data = {} - content_type = self.headers.pop('Content-Type', None) if self.headers else None + content_type = self.headers.pop("Content-Type", None) if self.headers else None - if content_type and content_type.lower() == 'application/x-www-form-urlencoded': + if content_type and content_type.lower() == "application/x-www-form-urlencoded": self.data = {f: d for f, d in data.items() if d is not None} self.files = None - else: # Assume "multipart/form-data" - self.files = {f: self._format_data(d) for f, d in data.items() if d is not None} + else: # Assume "multipart/form-data" + self.files = { + f: self._format_data(d) for f, d in data.items() if d is not None + } self.data = None def set_bytes_body(self, data): @@ -241,10 +334,85 @@ def set_bytes_body(self, data): :param data: The request field data. """ if data: - self.headers['Content-Length'] = str(len(data)) + self.headers["Content-Length"] = str(len(data)) self.data = data self.files = None + def set_multipart_mixed(self, *requests, **kwargs): + # type: (HttpRequest, Any) -> None + """Set the part of a multipart/mixed. + + Only support args for now are HttpRequest objects. + + boundary is optional, and one will be generated if you don't provide one. + Note that no verification are made on the boundary, this is considered advanced + enough so you know how to respect RFC1341 7.2.1 and provide a correct boundary. + + kwargs: + - policies: SansIOPolicy to apply at preparation time + - boundary: Optional boundary + + :param requests: HttpRequests object + """ + self.multipart_mixed_info = ( + requests, + kwargs.pop("policies", []), + kwargs.pop("boundary", []), + ) + + def prepare_multipart_body(self): + # type: () -> None + """Will prepare the body of this request according to the multipart information. + + This call assumes the on_request policies have been applied already in their + correct context (sync/async) + + Does nothing if "set_multipart_mixed" was never called. + """ + if not self.multipart_mixed_info: + return + + requests = self.multipart_mixed_info[0] # type: List[HttpRequest] + boundary = self.multipart_mixed_info[2] # type: Optional[str] + + # Update the main request with the body + main_message = Message() + main_message.add_header("Content-Type", "multipart/mixed") + if boundary: + main_message.set_boundary(boundary) + for i, req in enumerate(requests): + part_message = Message() + part_message.add_header("Content-Type", "application/http") + part_message.add_header("Content-Transfer-Encoding", "binary") + part_message.add_header("Content-ID", str(i)) + part_message.set_payload(req.serialize()) + main_message.attach(part_message) + + try: + from email.policy import HTTP + + full_message = main_message.as_bytes(policy=HTTP) + eol = b"\r\n" + except ImportError: # Python 2.7 + # Right now we decide to not support Python 2.7 on serialization, since + # it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it) + raise NotImplementedError( + "Multipart request are not supported on Python 2.7" + ) + # full_message = main_message.as_string() + # eol = b'\n' + _, _, body = full_message.split(eol, 2) + self.set_bytes_body(body) + self.headers["Content-Type"] = ( + "multipart/mixed; boundary=" + main_message.get_boundary() + ) + + def serialize(self): + # type: () -> bytes + """Serialize this request using application/http spec. + """ + return _serialize_request(self) + class _HttpResponseBase(object): """Represent a HTTP response. @@ -262,6 +430,7 @@ class _HttpResponseBase(object): :param str content_type: The content type. :param int block_size: Defaults to 4096 bytes. """ + def __init__(self, request, internal_response, block_size=None): # type: (HttpRequest, Any, Optional[int]) -> None self.request = request @@ -272,11 +441,11 @@ def __init__(self, request, internal_response, block_size=None): self.content_type = None # type: Optional[str] self.block_size = block_size or 4096 # Default to same as Requests - def body(self): # type: () -> bytes """Return the whole body as bytes in memory. """ + raise NotImplementedError() def text(self, encoding=None): # type: (str) -> str @@ -287,8 +456,47 @@ def text(self, encoding=None): """ return self.body().decode(encoding or "utf-8") + def _get_raw_parts(self, http_response_type=None): + # type (Optional[Type[_HttpResponseBase]]) -> Iterator[HttpResponse] + """Assuming this body is multipart, return the iterator or parts. -class HttpResponse(_HttpResponseBase): + If parts are application/http use http_response_type or HttpClientTransportResponse + as enveloppe. + """ + if http_response_type is None: + http_response_type = HttpClientTransportResponse + + body_as_bytes = self.body() + # In order to use email.message parser, I need full HTTP bytes. Faking something to make the parser happy + http_body = ( + b"Content-Type: " + + self.content_type.encode("ascii") + + b"\r\n\r\n" + + body_as_bytes + ) + + message = message_parser(http_body) # type: Message + + # Rebuild an HTTP response from pure string + requests = self.request.multipart_mixed_info[0] # type: List[HttpRequest] + responses = [] + for request, raw_reponse in zip(requests, message.get_payload()): + if raw_reponse.get_content_type() == "application/http": + responses.append( + _deserialize_response( + raw_reponse.get_payload(decode=True), + request, + http_response_type=http_response_type, + ) + ) + else: + raise ValueError( + "Multipart doesn't support part other than application/http for now" + ) + return responses + + +class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method def stream_download(self, pipeline): # type: (PipelineType) -> Iterator[bytes] """Generator for streaming request body data. @@ -297,6 +505,100 @@ def stream_download(self, pipeline): is supported. """ + def parts(self): + # type: () -> Iterator[HttpResponse] + """Assuming the content-type is multipart/mixed, will return the parts as an iterator. + + :rtype: iterator + :raises ValueError: If the content is not multipart/mixed + """ + if not self.content_type or not self.content_type.startswith("multipart/mixed"): + raise ValueError( + "You can't get parts if the response is not multipart/mixed" + ) + + responses = self._get_raw_parts() + if self.request.multipart_mixed_info: + policies = self.request.multipart_mixed_info[1] # type: List[SansIOHTTPPolicy] + + # Apply on_response concurrently to all requests + import concurrent.futures + + def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse( + http_request, response, context=context + ) + + for policy in policies: + _await_result(policy.on_response, pipeline_request, pipeline_response) + + with concurrent.futures.ThreadPoolExecutor() as executor: + # List comprehension to raise exceptions if happened + [ # pylint: disable=expression-not-assigned + _ for _ in executor.map(parse_responses, responses) + ] + + return responses + + +class _HttpClientTransportResponse(_HttpResponseBase): + """Create a HTTPResponse from an http.client response. + + Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. + + :param HttpRequest request: The request. + :param httpclient_response: The object returned from an HTTP(S)Connection from http.client + """ + + def __init__(self, request, httpclient_response): + super(_HttpClientTransportResponse, self).__init__(request, httpclient_response) + self.status_code = httpclient_response.status + self.headers = _case_insensitive_dict(httpclient_response.getheaders()) + self.reason = httpclient_response.reason + self.content_type = self.headers.get("Content-Type") + self.data = None + + def body(self): + if self.data is None: + self.data = self.internal_response.read() + return self.data + + +class HttpClientTransportResponse(_HttpClientTransportResponse, HttpResponse): + """Create a HTTPResponse from an http.client response. + + Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. + + :param HttpRequest request: The request. + :param httpclient_response: The object returned from an HTTP(S)Connection from http.client + """ + + +class BytesIOSocket(object): + """Mocking the "makefile" of socket for HTTPResponse. + + This can be used to create a http.client.HTTPResponse object + based on bytes and not a real socket. + """ + + def __init__(self, bytes_data): + self.bytes_data = bytes_data + + def makefile(self, *_): + return BytesIO(self.bytes_data) + + +def _deserialize_response( + http_response_as_bytes, http_request, http_response_type=HttpClientTransportResponse +): + local_socket = BytesIOSocket(http_response_as_bytes) + response = _HTTPResponse(local_socket, method=http_request.method) + response.begin() + return http_response_type(http_request, response) + class PipelineClientBase(object): """Base class for pipeline clients. @@ -308,14 +610,15 @@ def __init__(self, base_url): self._base_url = base_url def _request( - self, method, # type: str - url, # type: str - params, # type: Optional[Dict[str, str]] - headers, # type: Optional[Dict[str, str]] - content, # type: Any - form_content, # type: Optional[Dict[str, Any]] - stream_content, # type: Any - ): + self, + method, # type: str + url, # type: str + params, # type: Optional[Dict[str, str]] + headers, # type: Optional[Dict[str, str]] + content, # type: Any + form_content, # type: Optional[Dict[str, Any]] + stream_content, # type: Any + ): # type: (...) -> HttpRequest """Create HttpRequest object. @@ -362,20 +665,21 @@ def format_url(self, url_template, **kwargs): if url: parsed = urlparse(url) if not parsed.scheme or not parsed.netloc: - url = url.lstrip('/') - base = self._base_url.format(**kwargs).rstrip('/') + url = url.lstrip("/") + base = self._base_url.format(**kwargs).rstrip("/") url = _urljoin(base, url) else: url = self._base_url.format(**kwargs) return url def get( - self, url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None # type: Optional[Dict[str, Any]] - ): + self, + url, # type: str + params=None, # type: Optional[Dict[str, str]] + headers=None, # type: Optional[Dict[str, str]] + content=None, # type: Any + form_content=None, # type: Optional[Dict[str, Any]] + ): # type: (...) -> HttpRequest """Create a GET request object. @@ -386,18 +690,21 @@ def get( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('GET', url, params, headers, content, form_content, None) - request.method = 'GET' + request = self._request( + "GET", url, params, headers, content, form_content, None + ) + request.method = "GET" return request def put( - self, url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None # type: Any - ): + self, + url, # type: str + params=None, # type: Optional[Dict[str, str]] + headers=None, # type: Optional[Dict[str, str]] + content=None, # type: Any + form_content=None, # type: Optional[Dict[str, Any]] + stream_content=None, # type: Any + ): # type: (...) -> HttpRequest """Create a PUT request object. @@ -408,17 +715,20 @@ def put( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('PUT', url, params, headers, content, form_content, stream_content) + request = self._request( + "PUT", url, params, headers, content, form_content, stream_content + ) return request def post( - self, url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None # type: Any - ): + self, + url, # type: str + params=None, # type: Optional[Dict[str, str]] + headers=None, # type: Optional[Dict[str, str]] + content=None, # type: Any + form_content=None, # type: Optional[Dict[str, Any]] + stream_content=None, # type: Any + ): # type: (...) -> HttpRequest """Create a POST request object. @@ -429,17 +739,20 @@ def post( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('POST', url, params, headers, content, form_content, stream_content) + request = self._request( + "POST", url, params, headers, content, form_content, stream_content + ) return request def head( - self, url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None # type: Any - ): + self, + url, # type: str + params=None, # type: Optional[Dict[str, str]] + headers=None, # type: Optional[Dict[str, str]] + content=None, # type: Any + form_content=None, # type: Optional[Dict[str, Any]] + stream_content=None, # type: Any + ): # type: (...) -> HttpRequest """Create a HEAD request object. @@ -450,17 +763,20 @@ def head( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('HEAD', url, params, headers, content, form_content, stream_content) + request = self._request( + "HEAD", url, params, headers, content, form_content, stream_content + ) return request def patch( - self, url, # type: str - params=None, # type: Optional[Dict[str, str]] - headers=None, # type: Optional[Dict[str, str]] - content=None, # type: Any - form_content=None, # type: Optional[Dict[str, Any]] - stream_content=None # type: Any - ): + self, + url, # type: str + params=None, # type: Optional[Dict[str, str]] + headers=None, # type: Optional[Dict[str, str]] + content=None, # type: Any + form_content=None, # type: Optional[Dict[str, Any]] + stream_content=None, # type: Any + ): # type: (...) -> HttpRequest """Create a PATCH request object. @@ -471,7 +787,9 @@ def patch( :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('PATCH', url, params, headers, content, form_content, stream_content) + request = self._request( + "PATCH", url, params, headers, content, form_content, stream_content + ) return request def delete(self, url, params=None, headers=None, content=None, form_content=None): @@ -485,7 +803,9 @@ def delete(self, url, params=None, headers=None, content=None, form_content=None :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('DELETE', url, params, headers, content, form_content, None) + request = self._request( + "DELETE", url, params, headers, content, form_content, None + ) return request def merge(self, url, params=None, headers=None, content=None, form_content=None): @@ -499,5 +819,7 @@ def merge(self, url, params=None, headers=None, content=None, form_content=None) :return: An HttpRequest object :rtype: ~azure.core.pipeline.transport.HttpRequest """ - request = self._request('MERGE', url, params, headers, content, form_content, None) + request = self._request( + "MERGE", url, params, headers, content, form_content, None + ) return request diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/base_async.py index 8a3ed3a522f6..aa69ea619187 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/base_async.py @@ -26,15 +26,22 @@ import asyncio import abc - -from typing import Any, List, Union, Callable, AsyncIterator, Optional, Generic, TypeVar -from azure.core.pipeline import PipelineRequest, PipelineResponse, Pipeline -from azure.core.pipeline.policies import SansIOHTTPPolicy -from .base import _HttpResponseBase +from collections.abc import AsyncIterator + +from typing import AsyncIterator as AsyncIteratorType, Generic, TypeVar +from .base import ( + _HttpResponseBase, + _HttpClientTransportResponse, + PipelineContext, + PipelineRequest, + PipelineResponse, +) +from ..base_async import _await_result try: from contextlib import AbstractAsyncContextManager # type: ignore -except ImportError: # Python <= 3.7 +except ImportError: # Python <= 3.7 + class AbstractAsyncContextManager(object): # type: ignore async def __aenter__(self): """Return `self` upon entering the runtime context.""" @@ -65,12 +72,61 @@ def _iterate_response_content(iterator): raise _ResponseStopIteration() -class AsyncHttpResponse(_HttpResponseBase): +class _PartGenerator(AsyncIterator): + """Until parts is a real async iterator, wrap the sync call. + + :param parts: An iterable of parts + """ + + def __init__(self, response: "AsyncHttpResponse") -> None: + self._response = response + self._parts = None + + async def _parse_response(self): + responses = self._response._get_raw_parts( # pylint: disable=protected-access + http_response_type=AsyncHttpClientTransportResponse + ) + if self._response.request.multipart_mixed_info: + policies = self._response.request.multipart_mixed_info[ + 1 + ] # type: List[SansIOHTTPPolicy] + + async def parse_responses(response): + http_request = response.request + context = PipelineContext(None) + pipeline_request = PipelineRequest(http_request, context) + pipeline_response = PipelineResponse( + http_request, response, context=context + ) + + for policy in policies: + await _await_result( + policy.on_response, pipeline_request, pipeline_response + ) + + # Not happy to make this code asyncio specific, but that's multipart only for now + # If we need trio and multipart, let's reinvesitgate that later + await asyncio.gather(*[parse_responses(res) for res in responses]) + + return responses + + async def __anext__(self): + if not self._parts: + self._parts = iter(await self._parse_response()) + + try: + return next(self._parts) + except StopIteration: + raise StopAsyncIteration() + + +class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method """An AsyncHttpResponse ABC. Allows for the asynchronous streaming of data from the response. """ - def stream_download(self, pipeline) -> AsyncIterator[bytes]: + + def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. Should be implemented by sub-classes if streaming download @@ -80,8 +136,35 @@ def stream_download(self, pipeline) -> AsyncIterator[bytes]: :type pipeline: azure.core.pipeline """ + def parts(self) -> AsyncIterator: + """Assuming the content-type is multipart/mixed, will return the parts as an async iterator. + + :rtype: AsyncIterator + :raises ValueError: If the content is not multipart/mixed + """ + if not self.content_type or not self.content_type.startswith("multipart/mixed"): + raise ValueError( + "You can't get parts if the response is not multipart/mixed" + ) + + return _PartGenerator(self) + + +class AsyncHttpClientTransportResponse(_HttpClientTransportResponse, AsyncHttpResponse): + """Create a HTTPResponse from an http.client response. + + Body will NOT be read by the constructor. Call "body()" to load the body in memory if necessary. + + :param HttpRequest request: The request. + :param httpclient_response: The object returned from an HTTP(S)Connection from http.client + """ + -class AsyncHttpTransport(AbstractAsyncContextManager, abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]): +class AsyncHttpTransport( + AbstractAsyncContextManager, + abc.ABC, + Generic[HTTPRequestType, AsyncHTTPResponseType], +): """An http sender ABC. """ diff --git a/sdk/core/azure-core/setup.py b/sdk/core/azure-core/setup.py index bdab06be195c..d53d718637a2 100644 --- a/sdk/core/azure-core/setup.py +++ b/sdk/core/azure-core/setup.py @@ -62,6 +62,7 @@ ]), install_requires=[ 'requests>=2.18.4', + 'six>=1.6', ], extras_require={ ":python_version<'3.0'": ['azure-nspkg'], diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_basic_transport.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_basic_transport.py new file mode 100644 index 000000000000..8557f0b679cb --- /dev/null +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_basic_transport.py @@ -0,0 +1,273 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from six.moves.http_client import HTTPConnection +import time + +try: + from unittest import mock +except ImportError: + import mock + +from azure.core.pipeline.transport import HttpRequest, AsyncHttpResponse, AsyncHttpTransport +from azure.core.pipeline.policies import HeadersPolicy +from azure.core.pipeline import AsyncPipeline + +import pytest + + +@pytest.mark.asyncio +async def test_multipart_send(): + + # transport = mock.MagicMock(spec=AsyncHttpTransport) + # MagicMock support async cxt manager only after 3.8 + # https://github.com/python/cpython/pull/9296 + + class MockAsyncHttpTransport(AsyncHttpTransport): + async def __aenter__(self): return self + async def __aexit__(self, *args): pass + async def open(self): pass + async def close(self): pass + async def send(self, request, **kwargs): pass + + transport = MockAsyncHttpTransport() + + class RequestPolicy(object): + async def on_request(self, request): + # type: (PipelineRequest) -> None + request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' + + req0 = HttpRequest("DELETE", "/container0/blob0") + req1 = HttpRequest("DELETE", "/container1/blob1") + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed( + req0, + req1, + policies=[RequestPolicy()], + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" # Fix it so test are deterministic + ) + + async with AsyncPipeline(transport) as pipeline: + await pipeline.run(request) + + assert request.body == ( + b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' + b'Content-Type: application/http\r\n' + b'Content-Transfer-Encoding: binary\r\n' + b'Content-ID: 0\r\n' + b'\r\n' + b'DELETE /container0/blob0 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'\r\n' + b'\r\n' + b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' + b'Content-Type: application/http\r\n' + b'Content-Transfer-Encoding: binary\r\n' + b'Content-ID: 1\r\n' + b'\r\n' + b'DELETE /container1/blob1 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'\r\n' + b'\r\n' + b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + ) + + +@pytest.mark.asyncio +async def test_multipart_receive(): + + class MockResponse(AsyncHttpResponse): + def __init__(self, request, body, content_type): + super(MockResponse, self).__init__(request, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + class ResponsePolicy(object): + def on_response(self, request, response): + # type: (PipelineRequest, PipelineResponse) -> None + response.http_response.headers['x-ms-fun'] = 'true' + + class AsyncResponsePolicy(object): + async def on_response(self, request, response): + # type: (PipelineRequest, PipelineResponse) -> None + response.http_response.headers['x-ms-async-fun'] = 'true' + + req0 = HttpRequest("DELETE", "/container0/blob0") + req1 = HttpRequest("DELETE", "/container1/blob1") + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed( + req0, + req1, + policies=[ResponsePolicy(), AsyncResponsePolicy()] + ) + + body_as_str = ( + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 202 Accepted\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + "x-ms-version: 2018-11-09\r\n" + "\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 2\r\n" + "\r\n" + "HTTP/1.1 404 The specified blob does not exist.\r\n" + "x-ms-error-code: BlobNotFound\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852\r\n" + "x-ms-version: 2018-11-09\r\n" + "Content-Length: 216\r\n" + "Content-Type: application/xml\r\n" + "\r\n" + '\r\n' + "BlobNotFoundThe specified blob does not exist.\r\n" + "RequestId:778fdc83-801e-0000-62ff-0334671e2852\r\n" + "Time:2018-06-14T16:46:54.6040685Z\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + + response = MockResponse( + request, + body_as_str.encode('ascii'), + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + ) + + parts = [] + async for part in response.parts(): + parts.append(part) + + assert len(parts) == 2 + + res0 = parts[0] + assert res0.status_code == 202 + assert res0.headers['x-ms-fun'] == 'true' + assert res0.headers['x-ms-async-fun'] == 'true' + + res1 = parts[1] + assert res1.status_code == 404 + assert res1.headers['x-ms-fun'] == 'true' + assert res1.headers['x-ms-async-fun'] == 'true' + + +@pytest.mark.asyncio +async def test_multipart_receive_with_bom(): + + req0 = HttpRequest("DELETE", "/container0/blob0") + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(req0) + + class MockResponse(AsyncHttpResponse): + def __init__(self, request, body, content_type): + super(MockResponse, self).__init__(request, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + body_as_bytes = ( + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\n" + b"Content-Type: application/http\n" + b"Content-Transfer-Encoding: binary\n" + b"Content-ID: 0\n" + b'\r\n' + b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' + b'Content-Length: 220\r\n' + b'Content-Type: application/xml\r\n' + b'Server: Windows-Azure-Blob/1.0\r\n' + b'\r\n' + b'\xef\xbb\xbf\nInvalidInputOne' + b'of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + + response = MockResponse( + request, + body_as_bytes, + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + ) + + parts = [] + async for part in response.parts(): + parts.append(part) + assert len(parts) == 1 + + res0 = parts[0] + assert res0.status_code == 400 + assert res0.body().startswith(b'\xef\xbb\xbf') + + +@pytest.mark.asyncio +async def test_recursive_multipart_receive(): + req0 = HttpRequest("DELETE", "/container0/blob0") + internal_req0 = HttpRequest("DELETE", "/container0/blob0") + req0.set_multipart_mixed(internal_req0) + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(req0) + + class MockResponse(AsyncHttpResponse): + def __init__(self, request, body, content_type): + super(MockResponse, self).__init__(request, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + internal_body_as_str = ( + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 400 Accepted\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + "x-ms-version: 2018-11-09\r\n" + "\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + + body_as_str = ( + "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 202 Accepted\r\n" + "Content-Type: multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "\r\n" + "{}" + "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6--" + ).format(internal_body_as_str) + + response = MockResponse( + request, + body_as_str.encode('ascii'), + "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" + ) + + parts = [] + async for part in response.parts(): + parts.append(part) + + assert len(parts) == 1 + + res0 = parts[0] + assert res0.status_code == 202 + + internal_parts = [] + async for part in res0.parts(): + internal_parts.append(part) + assert len(internal_parts) == 1 + + internal_response0 = internal_parts[0] + assert internal_response0.status_code == 400 diff --git a/sdk/core/azure-core/tests/test_basic_transport.py b/sdk/core/azure-core/tests/test_basic_transport.py new file mode 100644 index 000000000000..797c0589049d --- /dev/null +++ b/sdk/core/azure-core/tests/test_basic_transport.py @@ -0,0 +1,387 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from six.moves.http_client import HTTPConnection +from collections import OrderedDict +import time +import sys + +try: + from unittest import mock +except ImportError: + import mock + +from azure.core.pipeline.transport import HttpRequest, HttpResponse, RequestsTransport +from azure.core.pipeline.transport.base import HttpClientTransportResponse, HttpTransport, _deserialize_response +from azure.core.pipeline.policies import HeadersPolicy +from azure.core.pipeline import Pipeline + +import pytest + + +@pytest.mark.skipif(sys.version_info < (3, 6), reason="Multipart serialization not supported on 2.7 + dict order not deterministic on 3.5") +def test_http_request_serialization(): + # Method + Url + request = HttpRequest("DELETE", "/container0/blob0") + serialized = request.serialize() + + expected = ( + b'DELETE /container0/blob0 HTTP/1.1\r\n' + # No headers + b'\r\n' + ) + assert serialized == expected + + # Method + Url + Headers + request = HttpRequest( + "DELETE", + "/container0/blob0", + # Use OrderedDict to get consistent test result on 3.5 where order is not guaranted + headers=OrderedDict({ + "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", + "Authorization": "SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=", + "Content-Length": "0", + }) + ) + serialized = request.serialize() + + expected = ( + b'DELETE /container0/blob0 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'Authorization: SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=\r\n' + b'Content-Length: 0\r\n' + b'\r\n' + ) + assert serialized == expected + + + # Method + Url + Headers + Body + request = HttpRequest( + "DELETE", + "/container0/blob0", + headers={ + "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", + }, + ) + request.set_bytes_body(b"I am groot") + serialized = request.serialize() + + expected = ( + b'DELETE /container0/blob0 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'Content-Length: 10\r\n' + b'\r\n' + b'I am groot' + ) + assert serialized == expected + + +def test_http_client_response(): + # Create a core request + request = HttpRequest("GET", "www.httpbin.org") + + # Fake a transport based on http.client + conn = HTTPConnection("www.httpbin.org") + conn.request("GET", "/get") + r1 = conn.getresponse() + + response = HttpClientTransportResponse(request, r1) + + # Don't assume too much in those assert, since we reach a real server + assert response.internal_response is r1 + assert response.reason is not None + assert response.status_code == 200 + assert len(response.headers.keys()) != 0 + assert len(response.text()) != 0 + assert "content-type" in response.headers + assert "Content-Type" in response.headers + + +def test_response_deserialization(): + + # Method + Url + request = HttpRequest("DELETE", "/container0/blob0") + body = ( + b'HTTP/1.1 202 Accepted\r\n' + b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' + b'x-ms-version: 2018-11-09\r\n' + ) + + response = _deserialize_response(body, request) + + assert response.status_code == 202 + assert response.reason == "Accepted" + assert response.headers == { + 'x-ms-request-id': '778fdc83-801e-0000-62ff-0334671e284f', + 'x-ms-version': '2018-11-09' + } + + # Method + Url + Headers + Body + request = HttpRequest( + "DELETE", + "/container0/blob0", + headers={ + "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", + }, + ) + request.set_bytes_body(b"I am groot") + body = ( + b'HTTP/1.1 200 OK\r\n' + b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' + b'x-ms-version: 2018-11-09\r\n' + b'\r\n' + b'I am groot' + ) + + response = _deserialize_response(body, request) + + assert response.status_code == 200 + assert response.reason == "OK" + assert response.headers == { + 'x-ms-request-id': '778fdc83-801e-0000-62ff-0334671e284f', + 'x-ms-version': '2018-11-09' + } + assert response.text() == "I am groot" + +def test_response_deserialization_utf8_bom(): + + request = HttpRequest("DELETE", "/container0/blob0") + body = ( + b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' + b'x-ms-error-code: InvalidInput\r\n' + b'x-ms-request-id: 5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\r\n' + b'x-ms-version: 2019-02-02\r\n' + b'Content-Length: 220\r\n' + b'Content-Type: application/xml\r\n' + b'Server: Windows-Azure-Blob/1.0\r\n' + b'\r\n' + b'\xef\xbb\xbf\nInvalidInputOne' + b'of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z' + ) + response = _deserialize_response(body, request) + assert response.body().startswith(b'\xef\xbb\xbf') + + +@pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") +def test_multipart_send(): + + transport = mock.MagicMock(spec=HttpTransport) + + header_policy = HeadersPolicy({ + 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' + }) + + req0 = HttpRequest("DELETE", "/container0/blob0") + req1 = HttpRequest("DELETE", "/container1/blob1") + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed( + req0, + req1, + policies=[header_policy], + boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" # Fix it so test are deterministic + ) + + with Pipeline(transport) as pipeline: + pipeline.run(request) + + assert request.body == ( + b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' + b'Content-Type: application/http\r\n' + b'Content-Transfer-Encoding: binary\r\n' + b'Content-ID: 0\r\n' + b'\r\n' + b'DELETE /container0/blob0 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'\r\n' + b'\r\n' + b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525\r\n' + b'Content-Type: application/http\r\n' + b'Content-Transfer-Encoding: binary\r\n' + b'Content-ID: 1\r\n' + b'\r\n' + b'DELETE /container1/blob1 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'\r\n' + b'\r\n' + b'--batch_357de4f7-6d0b-4e02-8cd2-6361411a9525--\r\n' + ) + + +def test_multipart_receive(): + + class MockResponse(HttpResponse): + def __init__(self, request, body, content_type): + super(MockResponse, self).__init__(request, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + class ResponsePolicy(object): + def on_response(self, request, response): + # type: (PipelineRequest, PipelineResponse) -> None + response.http_response.headers['x-ms-fun'] = 'true' + + req0 = HttpRequest("DELETE", "/container0/blob0") + req1 = HttpRequest("DELETE", "/container1/blob1") + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed( + req0, + req1, + policies=[ResponsePolicy()] + ) + + body_as_str = ( + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 202 Accepted\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + "x-ms-version: 2018-11-09\r\n" + "\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 2\r\n" + "\r\n" + "HTTP/1.1 404 The specified blob does not exist.\r\n" + "x-ms-error-code: BlobNotFound\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e2852\r\n" + "x-ms-version: 2018-11-09\r\n" + "Content-Length: 216\r\n" + "Content-Type: application/xml\r\n" + "\r\n" + '\r\n' + "BlobNotFoundThe specified blob does not exist.\r\n" + "RequestId:778fdc83-801e-0000-62ff-0334671e2852\r\n" + "Time:2018-06-14T16:46:54.6040685Z\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + + response = MockResponse( + request, + body_as_str.encode('ascii'), + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + ) + + response = response.parts() + + assert len(response) == 2 + + res0 = response[0] + assert res0.status_code == 202 + assert res0.headers['x-ms-fun'] == 'true' + + res1 = response[1] + assert res1.status_code == 404 + assert res1.headers['x-ms-fun'] == 'true' + +def test_multipart_receive_with_bom(): + + req0 = HttpRequest("DELETE", "/container0/blob0") + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(req0) + + class MockResponse(HttpResponse): + def __init__(self, request, body, content_type): + super(MockResponse, self).__init__(request, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + body_as_bytes = ( + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\n" + b"Content-Type: application/http\n" + b"Content-Transfer-Encoding: binary\n" + b"Content-ID: 0\n" + b'\r\n' + b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' + b'Content-Length: 220\r\n' + b'Content-Type: application/xml\r\n' + b'Server: Windows-Azure-Blob/1.0\r\n' + b'\r\n' + b'\xef\xbb\xbf\nInvalidInputOne' + b'of the request inputs is not valid.\nRequestId:5f3f9f2f-e01e-00cc-6eb1-6d00b5000000\nTime:2019-09-17T23:44:07.4671860Z\n' + b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + + response = MockResponse( + request, + body_as_bytes, + "multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed" + ) + + response = response.parts() + assert len(response) == 1 + + res0 = response[0] + assert res0.status_code == 400 + assert res0.body().startswith(b'\xef\xbb\xbf') + + +def test_recursive_multipart_receive(): + req0 = HttpRequest("DELETE", "/container0/blob0") + internal_req0 = HttpRequest("DELETE", "/container0/blob0") + req0.set_multipart_mixed(internal_req0) + + request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(req0) + + class MockResponse(HttpResponse): + def __init__(self, request, body, content_type): + super(MockResponse, self).__init__(request, None) + self._body = body + self.content_type = content_type + + def body(self): + return self._body + + internal_body_as_str = ( + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 400 Accepted\r\n" + "x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n" + "x-ms-version: 2018-11-09\r\n" + "\r\n" + "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed--" + ) + + body_as_str = ( + "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6\r\n" + "Content-Type: application/http\r\n" + "Content-ID: 0\r\n" + "\r\n" + "HTTP/1.1 202 Accepted\r\n" + "Content-Type: multipart/mixed; boundary=batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" + "\r\n" + "{}" + "--batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6--" + ).format(internal_body_as_str) + + response = MockResponse( + request, + body_as_str.encode('ascii'), + "multipart/mixed; boundary=batchresponse_8d5f5bcd-2cb5-44bb-91b5-e9a722e68cb6" + ) + + response = response.parts() + assert len(response) == 1 + + res0 = response[0] + assert res0.status_code == 202 + + internal_response = res0.parts() + assert len(internal_response) == 1 + + internal_response0 = internal_response[0] + assert internal_response0.status_code == 400