Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdk/core/azure-core/azure/core/pipeline/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def _prepare_multipart_mixed_request(request):
import concurrent.futures

def prepare_requests(req):
if req.multipart_mixed_info:
# Recursively update changeset "sub requests"
Pipeline._prepare_multipart_mixed_request(req)
context = PipelineContext(None, **pipeline_options)
pipeline_request = PipelineRequest(req, context)
for policy in policies:
Expand Down
3 changes: 3 additions & 0 deletions sdk/core/azure-core/azure/core/pipeline/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ async def _prepare_multipart_mixed_request(self, request):
pipeline_options = multipart_mixed_info[3] # type: Dict[str, Any]

async def prepare_requests(req):
if req.multipart_mixed_info:
# Recursively update changeset "sub requests"
await self._prepare_multipart_mixed_request(req)
context = PipelineContext(None, **pipeline_options)
pipeline_request = PipelineRequest(req, context)
for policy in policies:
Expand Down
79 changes: 51 additions & 28 deletions sdk/core/azure-core/azure/core/pipeline/transport/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
Optional,
Tuple,
Iterator,
Type
)

from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse
Expand Down Expand Up @@ -379,27 +380,30 @@ def set_multipart_mixed(self, *requests, **kwargs):

:keyword list[SansIOHTTPPolicy] policies: SansIOPolicy to apply at preparation time
:keyword str boundary: Optional boundary

:param requests: HttpRequests object
"""
self.multipart_mixed_info = (
requests,
kwargs.pop("policies", []),
kwargs.pop("boundary", []),
kwargs.pop("boundary", None),
kwargs
)

def prepare_multipart_body(self):
# type: () -> None
def prepare_multipart_body(self, content_index=0):
# type: (int) -> int
"""Will prepare the body of this request according to the multipart information.

This call assumes the on_request policies have been applied already in their
correct context (sync/async)

Does nothing if "set_multipart_mixed" was never called.

:param int content_index: The current index of parts within the batch message.
:returns: The updated index after all parts in this request have been added.
:rtype: int
"""
if not self.multipart_mixed_info:
return
return 0

requests = self.multipart_mixed_info[0] # type: List[HttpRequest]
boundary = self.multipart_mixed_info[2] # type: Optional[str]
Expand All @@ -409,12 +413,22 @@ def prepare_multipart_body(self):
main_message.add_header("Content-Type", "multipart/mixed")
if boundary:
main_message.set_boundary(boundary)
for i, req in enumerate(requests):

for req in requests:
part_message = Message()
part_message.add_header("Content-Type", "application/http")
part_message.add_header("Content-Transfer-Encoding", "binary")
part_message.add_header("Content-ID", str(i))
part_message.set_payload(req.serialize())
if req.multipart_mixed_info:
content_index = req.prepare_multipart_body(content_index=content_index)
part_message.add_header("Content-Type", req.headers['Content-Type'])
payload = req.serialize()
# We need to remove the ~HTTP/1.1 prefix along with the added content-length
payload = payload[payload.index(b'--'):]
else:
part_message.add_header("Content-Type", "application/http")
part_message.add_header("Content-Transfer-Encoding", "binary")
part_message.add_header("Content-ID", str(content_index))
payload = req.serialize()
content_index += 1
part_message.set_payload(payload)
main_message.attach(part_message)

try:
Expand All @@ -435,6 +449,7 @@ def prepare_multipart_body(self):
self.headers["Content-Type"] = (
"multipart/mixed; boundary=" + main_message.get_boundary()
)
return content_index

def serialize(self):
# type: () -> bytes
Expand Down Expand Up @@ -485,6 +500,31 @@ def text(self, encoding=None):
encoding = "utf-8-sig"
return self.body().decode(encoding)

def _decode_parts(self, message, http_response_type, requests):
# type: (Message, Type[_HttpResponseBase], List[HttpRequest]) -> List[HttpResponse]
"""Rebuild an HTTP response from pure string."""
responses = []
for index, raw_reponse in enumerate(message.get_payload()):
content_type = raw_reponse.get_content_type()
if content_type == "application/http":
responses.append(
_deserialize_response(
raw_reponse.get_payload(decode=True),
requests[index],
http_response_type=http_response_type,
)
)
elif content_type == "multipart/mixed" and requests[index].multipart_mixed_info:
# The message batch contains one or more change sets
changeset_requests = requests[index].multipart_mixed_info[0] # type: ignore
changeset_responses = self._decode_parts(raw_reponse, http_response_type, changeset_requests)
responses.extend(changeset_responses)
else:
raise ValueError(
"Multipart doesn't support part other than application/http for now"
)
return responses

def _get_raw_parts(self, http_response_type=None):
# type (Optional[Type[_HttpResponseBase]]) -> Iterator[HttpResponse]
"""Assuming this body is multipart, return the iterator or parts.
Expand All @@ -503,26 +543,9 @@ def _get_raw_parts(self, http_response_type=None):
+ b"\r\n\r\n"
+ body_as_bytes
)

message = message_parser(http_body) # type: Message

# Rebuild an HTTP response from pure string
requests = self.request.multipart_mixed_info[0] # type: List[HttpRequest]
responses = []
for request, raw_reponse in zip(requests, message.get_payload()):
if raw_reponse.get_content_type() == "application/http":
responses.append(
_deserialize_response(
raw_reponse.get_payload(decode=True),
request,
http_response_type=http_response_type,
)
)
else:
raise ValueError(
"Multipart doesn't support part other than application/http for now"
)
return responses
return self._decode_parts(message, http_response_type, requests)


class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
Expand Down
Loading