diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 7a2d0c81144d..64328d902433 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -2,8 +2,9 @@ ## 1.12.0 (Unreleased) -### Features +### Bug Fixes +- Raise exception rather than swallowing it if there is something wrong in retry stream downloading #16723 - Added `azure.core.messaging.CloudEvent` model that follows the cloud event spec. - Added `azure.core.serialization.NULL` sentinel value diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 9468ac9d3756..7b0e643ff8ed 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -39,7 +39,7 @@ from azure.core.exceptions import ServiceRequestError, ServiceResponseError from azure.core.pipeline import Pipeline -from ._base import HttpRequest +from ._base import HttpRequest, parse_range_header, make_range_header from ._base_async import ( AsyncHttpTransport, AsyncHttpResponse, @@ -195,7 +195,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR return response -class AioHttpStreamDownloadGenerator(AsyncIterator): +class AioHttpStreamDownloadGenerator(AsyncIterator): # pylint: disable=too-many-instance-attributes """Streams the response body data. :param pipeline: The pipeline object @@ -208,48 +208,78 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None: self.request = response.request self.response = response self.block_size = response.block_size - self.content_length = int(response.internal_response.headers.get('Content-Length', 0)) self.downloaded = 0 + headers = response.internal_response.headers + self.content_length = int(headers.get('Content-Length', 0)) + transfer_header = headers.get('Transfer-Encoding', '') + self._compressed = 'compress' in transfer_header or 'deflate' in transfer_header or 'gzip' in transfer_header + if "x-ms-range" in headers: + self.range_header = "x-ms-range" # type: Optional[str] + self.range = parse_range_header(headers["x-ms-range"]) + elif "Range" in headers: + self.range_header = "Range" + self.range = parse_range_header(headers["Range"]) + else: + self.range_header = None + self.etag = headers.get('etag') def __len__(self): return self.content_length - async def __anext__(self): + async def __anext__(self): # pylint:disable=too-many-statements retry_active = True retry_total = 3 retry_interval = 1 # 1 second - while retry_active: - try: - chunk = await self.response.internal_response.content.read(self.block_size) - if not chunk: - raise _ResponseStopIteration() - self.downloaded += self.block_size - return chunk - except _ResponseStopIteration: + try: + chunk = await self.response.internal_response.content.read(self.block_size) + if not chunk: self.response.internal_response.close() - raise StopAsyncIteration() - except (ChunkedEncodingError, ConnectionError): + raise _ResponseStopIteration() + self.downloaded += self.block_size + return chunk + except _ResponseStopIteration: + raise StopAsyncIteration() + except (ChunkedEncodingError, ConnectionError) as ex: + if self._compressed: + raise ex + while retry_active: retry_total -= 1 if retry_total <= 0: - retry_active = False + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + if not self.etag: + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + await asyncio.sleep(retry_interval) + headers = self.request.headers.copy() + if not self.range_header: + range_header = {'range': 'bytes=' + str(self.downloaded) + '-'} else: - await asyncio.sleep(retry_interval) - headers = {'range': 'bytes=' + str(self.downloaded) + '-'} + range_header = {self.range_header: make_range_header(self.range, self.downloaded)} + range_header.update({'If-Match': self.etag}) + headers.update(range_header) + try: resp = await self.pipeline.run(self.request, stream=True, headers=headers) - if resp.http_response.status_code == 416: - raise + if not resp.http_response: + continue + if resp.http_response.status_code == 412: + continue + self.response = resp.http_response chunk = await self.response.internal_response.content.read(self.block_size) if not chunk: + self.response.internal_response.close() raise StopAsyncIteration() - self.downloaded += len(chunk) + self.downloaded += self.block_size return chunk - continue - except StreamConsumedError: - raise - except Exception as err: - _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() - raise + except StopAsyncIteration: + raise StopAsyncIteration() + except Exception: # pylint: disable=broad-except + continue + except StreamConsumedError: + raise + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + raise class AioHttpTransportResponse(AsyncHttpResponse): """Methods for accessing response body data. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index c77212d33e69..d3e4e041f267 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -963,3 +963,25 @@ def options(self, url, params=None, headers=None, **kwargs): "OPTIONS", url, params, headers, content, form_content, None ) return request + + +def parse_range_header(header_value): + range_value = header_value.strip() + if not range_value.startswith("bytes="): + raise ValueError("Invalid header") + range_str = range_value[6:] + ret = range_str.split("-") + if len(ret) < 2: + raise ValueError("Invalid header") + start = int(ret[0]) if ret[0] else -1 + end = int(ret[1]) if ret[1] else -1 + return (start, end) + +def make_range_header(original_range, downloaded_size=0): + if original_range[0] == -1: + end = original_range[1] - downloaded_size + return "bytes=-" + str(end) + start = original_range[0] + downloaded_size + if original_range[1] == -1: + return "bytes=" + str(start) + "-" + return "bytes=" + str(start) + "-" + str(original_range[1]) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index 087a373b0732..156ef5a39cc1 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -37,7 +37,7 @@ ServiceResponseError ) from azure.core.pipeline import Pipeline -from ._base import HttpRequest +from ._base import HttpRequest, parse_range_header, make_range_header from ._base_async import ( AsyncHttpResponse, _ResponseStopIteration, @@ -133,7 +133,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size) -class AsyncioStreamDownloadGenerator(AsyncIterator): +class AsyncioStreamDownloadGenerator(AsyncIterator): # pylint: disable=too-many-instance-attributes """Streams the response body data. :param pipeline: The pipeline object @@ -147,59 +147,86 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None: self.response = response self.block_size = response.block_size self.iter_content_func = self.response.internal_response.iter_content(self.block_size) - self.content_length = int(response.headers.get('Content-Length', 0)) self.downloaded = 0 + headers = response.internal_response.headers + self.content_length = int(headers.get('Content-Length', 0)) + transfer_header = headers.get('Transfer-Encoding', '') + self._compressed = 'compress' in transfer_header or 'deflate' in transfer_header or 'gzip' in transfer_header + if "x-ms-range" in headers: + self.range_header = "x-ms-range" # type: Optional[str] + self.range = parse_range_header(headers["x-ms-range"]) + elif "Range" in headers: + self.range_header = "Range" + self.range = parse_range_header(headers["Range"]) + else: + self.range_header = None + self.etag = headers.get('etag') def __len__(self): return self.content_length - async def __anext__(self): + async def __anext__(self): # pylint:disable=too-many-statements loop = _get_running_loop() retry_active = True retry_total = 3 retry_interval = 1 # 1 second - while retry_active: - try: - chunk = await loop.run_in_executor( - None, - _iterate_response_content, - self.iter_content_func, - ) - if not chunk: - raise _ResponseStopIteration() - self.downloaded += self.block_size - return chunk - except _ResponseStopIteration: - self.response.internal_response.close() - raise StopAsyncIteration() - except (requests.exceptions.ChunkedEncodingError, - requests.exceptions.ConnectionError): + try: + chunk = await loop.run_in_executor( + None, + _iterate_response_content, + self.iter_content_func, + ) + if not chunk: + raise _ResponseStopIteration() + self.downloaded += self.block_size + return chunk + except _ResponseStopIteration: + raise StopAsyncIteration() + except (requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError) as ex: + if self._compressed: + raise ex + while retry_active: retry_total -= 1 if retry_total <= 0: - retry_active = False + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + if not self.etag: + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + await asyncio.sleep(retry_interval) + headers = self.request.headers.copy() + if not self.range_header: + range_header = {'range': 'bytes=' + str(self.downloaded) + '-'} else: - await asyncio.sleep(retry_interval) - headers = {'range': 'bytes=' + str(self.downloaded) + '-'} - resp = self.pipeline.run(self.request, stream=True, headers=headers) - if resp.status_code == 416: - raise + range_header = {self.range_header: make_range_header(self.range, self.downloaded)} + range_header.update({'If-Match': self.etag}) + headers.update(range_header) + try: + resp = await self.pipeline.run(self.request, stream=True, headers=headers) + if not resp.http_response: + continue + if resp.http_response.status_code == 412: + continue + self.response = resp.http_response chunk = await loop.run_in_executor( None, _iterate_response_content, self.iter_content_func, ) if not chunk: - raise StopIteration() - self.downloaded += len(chunk) + raise StopAsyncIteration() + self.downloaded += self.block_size return chunk - continue - except requests.exceptions.StreamConsumedError: - raise - except Exception as err: - _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() - raise - + except StopAsyncIteration: + raise StopAsyncIteration() + except Exception: # pylint: disable=broad-except + continue + except requests.exceptions.StreamConsumedError: + raise + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + raise class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore """Asynchronous streaming of data from the response. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index a6d8582422bd..780e04c83798 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -41,7 +41,9 @@ from ._base import ( HttpTransport, HttpResponse, - _HttpResponseBase + _HttpResponseBase, + make_range_header, + parse_range_header, ) from ._bigger_block_size_http_adapters import BiggerBlockSizeHTTPAdapter @@ -49,7 +51,6 @@ _LOGGER = logging.getLogger(__name__) - class _RequestsTransportResponseBase(_HttpResponseBase): """Base class for accessing response data. @@ -94,7 +95,7 @@ def text(self, encoding=None): return self.internal_response.text -class StreamDownloadGenerator(object): +class StreamDownloadGenerator(object): # pylint: disable=too-many-instance-attributes """Generator for streaming response data. :param pipeline: The pipeline object @@ -106,8 +107,20 @@ def __init__(self, pipeline, response): self.response = response self.block_size = response.block_size self.iter_content_func = self.response.internal_response.iter_content(self.block_size) - self.content_length = int(response.headers.get('Content-Length', 0)) self.downloaded = 0 + headers = response.internal_response.headers + self.content_length = int(headers.get('Content-Length', 0)) + transfer_header = headers.get('Transfer-Encoding', '') + self._compressed = 'compress' in transfer_header or 'deflate' in transfer_header or 'gzip' in transfer_header + if "x-ms-range" in headers: + self.range_header = "x-ms-range" + self.range = parse_range_header(headers["x-ms-range"]) + elif "Range" in headers: + self.range_header = "Range" + self.range = parse_range_header(headers["Range"]) + else: + self.range_header = None + self.etag = headers.get('etag') def __len__(self): return self.content_length @@ -115,43 +128,62 @@ def __len__(self): def __iter__(self): return self - def __next__(self): + def __next__(self): # pylint:disable=too-many-statements retry_active = True retry_total = 3 retry_interval = 1 # 1 second - while retry_active: - try: - chunk = next(self.iter_content_func) - if not chunk: - raise StopIteration() - self.downloaded += self.block_size - return chunk - except StopIteration: + try: + chunk = next(self.iter_content_func) + if not chunk: self.response.internal_response.close() raise StopIteration() - except (requests.exceptions.ChunkedEncodingError, - requests.exceptions.ConnectionError): + self.downloaded += self.block_size + return chunk + except StopIteration: + raise StopIteration() + except (requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError) as ex: + if self._compressed: + raise ex + while retry_active: retry_total -= 1 if retry_total <= 0: - retry_active = False + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + if not self.etag: + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + time.sleep(retry_interval) + headers = self.request.headers.copy() + if not self.range_header: + range_header = {'range': 'bytes=' + str(self.downloaded) + '-'} else: - time.sleep(retry_interval) - headers = {'range': 'bytes=' + str(self.downloaded) + '-'} + range_header = {self.range_header: make_range_header(self.range, self.downloaded)} + range_header.update({'If-Match': self.etag}) + headers.update(range_header) + try: resp = self.pipeline.run(self.request, stream=True, headers=headers) - if resp.http_response.status_code == 416: - raise + if not resp.http_response: + continue + if resp.http_response.status_code == 412: + raise ex + self.response = resp.http_response + self.iter_content_func = self.response.internal_response.iter_content(self.block_size) chunk = next(self.iter_content_func) if not chunk: + self.response.internal_response.close() raise StopIteration() - self.downloaded += len(chunk) + self.downloaded += self.block_size return chunk - continue - except requests.exceptions.StreamConsumedError: - raise - except Exception as err: - _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() - raise + except StopIteration: + raise StopIteration() + except Exception: # pylint: disable=broad-except + continue + except requests.exceptions.StreamConsumedError: + raise + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + raise next = __next__ # Python 2 compatibility. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index bfc28735df5c..5ab660a60e41 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -37,7 +37,7 @@ ServiceResponseError ) from azure.core.pipeline import Pipeline -from ._base import HttpRequest +from ._base import HttpRequest, parse_range_header, make_range_header from ._base_async import ( AsyncHttpResponse, _ResponseStopIteration, @@ -49,7 +49,7 @@ _LOGGER = logging.getLogger(__name__) -class TrioStreamDownloadGenerator(AsyncIterator): +class TrioStreamDownloadGenerator(AsyncIterator): # pylint: disable=too-many-instance-attributes """Generator for streaming response data. :param pipeline: The pipeline object @@ -61,45 +61,71 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None: self.response = response self.block_size = response.block_size self.iter_content_func = self.response.internal_response.iter_content(self.block_size) - self.content_length = int(response.headers.get('Content-Length', 0)) self.downloaded = 0 + headers = response.internal_response.headers + self.content_length = int(headers.get('Content-Length', 0)) + transfer_header = headers.get('Transfer-Encoding', '') + self._compressed = 'compress' in transfer_header or 'deflate' in transfer_header or 'gzip' in transfer_header + if "x-ms-range" in headers: + self.range_header = "x-ms-range" # type: Optional[str] + self.range = parse_range_header(headers["x-ms-range"]) + elif "Range" in headers: + self.range_header = "Range" + self.range = parse_range_header(headers["Range"]) + else: + self.range_header = None + self.etag = headers.get('etag') def __len__(self): return self.content_length - async def __anext__(self): + async def __anext__(self): # pylint:disable=too-many-statements retry_active = True retry_total = 3 - while retry_active: + retry_interval = 1 # 1 second + try: try: - try: - chunk = await trio.to_thread.run_sync( - _iterate_response_content, - self.iter_content_func, - ) - except AttributeError: # trio < 0.12.1 - chunk = await trio.run_sync_in_worker_thread( # pylint: disable=no-member - _iterate_response_content, - self.iter_content_func, - ) - if not chunk: - raise _ResponseStopIteration() - self.downloaded += self.block_size - return chunk - except _ResponseStopIteration: - self.response.internal_response.close() - raise StopAsyncIteration() - except (requests.exceptions.ChunkedEncodingError, - requests.exceptions.ConnectionError): + chunk = await trio.to_thread.run_sync( + _iterate_response_content, + self.iter_content_func, + ) + except AttributeError: # trio < 0.12.1 + chunk = await trio.run_sync_in_worker_thread( # pylint: disable=no-member + _iterate_response_content, + self.iter_content_func, + ) + if not chunk: + raise _ResponseStopIteration() + self.downloaded += self.block_size + return chunk + except _ResponseStopIteration: + raise StopAsyncIteration() + except (requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError) as ex: + if self._compressed: + raise ex + while retry_active: retry_total -= 1 if retry_total <= 0: - retry_active = False + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + if not self.etag: + _LOGGER.warning("Unable to stream download: %s", ex) + raise ex + await trio.sleep(retry_interval) + headers = self.request.headers.copy() + if not self.range_header: + range_header = {'range': 'bytes=' + str(self.downloaded) + '-'} else: - await trio.sleep(1) - headers = {'range': 'bytes=' + str(self.downloaded) + '-'} + range_header = {self.range_header: make_range_header(self.range, self.downloaded)} + range_header.update({'If-Match': self.etag}) + headers.update(range_header) + try: resp = self.pipeline.run(self.request, stream=True, headers=headers) - if resp.status_code == 416: - raise + if not resp.http_response: + continue + if resp.http_response.status_code == 412: + continue try: chunk = await trio.to_thread.run_sync( _iterate_response_content, @@ -111,16 +137,18 @@ async def __anext__(self): self.iter_content_func, ) if not chunk: - raise StopIteration() - self.downloaded += len(chunk) + raise StopAsyncIteration() + self.downloaded += self.block_size return chunk - continue - except requests.exceptions.StreamConsumedError: - raise - except Exception as err: - _LOGGER.warning("Unable to stream download: %s", err) - self.response.internal_response.close() - raise + except StopAsyncIteration: + raise StopAsyncIteration() + except Exception: # pylint: disable=broad-except + continue + except requests.exceptions.StreamConsumedError: + raise + except Exception as err: + _LOGGER.warning("Unable to stream download: %s", err) + raise class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore """Asynchronous streaming of data from the response. diff --git a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py index f7368159615f..486b570c6afe 100644 --- a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py @@ -2,12 +2,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import requests from azure.core.pipeline.transport import ( HttpRequest, AsyncHttpResponse, AsyncHttpTransport, + AsyncioRequestsTransportResponse, + AioHttpTransport, ) -from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline import AsyncPipeline, PipelineResponse from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator from unittest import mock import pytest @@ -15,8 +18,8 @@ @pytest.mark.asyncio async def test_connection_error_response(): class MockTransport(AsyncHttpTransport): - def __init__(self): - self._count = 0 + def __init__(self, error=True): + self._error = error async def __aexit__(self, exc_type, exc_val, exc_tb): pass @@ -29,22 +32,24 @@ async def send(self, request, **kwargs): request = HttpRequest('GET', 'http://127.0.0.1/') response = AsyncHttpResponse(request, None) response.status_code = 200 + response.internal_response = MockInternalResponse(error=False) return response class MockContent(): - def __init__(self): - self._first = True + def __init__(self, error=True): + self._error = error async def read(self, block_size): - if self._first: - self._first = False + if self._error: raise ConnectionError return None class MockInternalResponse(): - def __init__(self): - self.headers = {} - self.content = MockContent() + def __init__(self, error=True): + self.headers = {"etag": "etag"} + self._error = error + self.content = MockContent(error=self._error) + self.status_code = 200 async def close(self): pass @@ -90,7 +95,7 @@ def __init__(self): self.headers = {} self.content = MockContent() - async def close(self): + def close(self): pass class AsyncMock(mock.MagicMock): @@ -105,3 +110,50 @@ async def __call__(self, *args, **kwargs): with mock.patch('asyncio.sleep', new_callable=AsyncMock): with pytest.raises(ConnectionError): await stream.__anext__() + +@pytest.mark.asyncio +async def test_response_streaming_error_behavior(): + # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 + block_size = 103 + total_response_size = 500 + req_response = requests.Response() + req_request = requests.Request() + + class FakeStreamWithConnectionError: + # fake object for urllib3.response.HTTPResponse + + def stream(self, chunk_size, decode_content=False): + assert chunk_size == block_size + left = total_response_size + while left > 0: + if left <= block_size: + raise requests.exceptions.ConnectionError() + data = b"X" * min(chunk_size, left) + left -= len(data) + yield data + + def close(self): + pass + + req_response.raw = FakeStreamWithConnectionError() + + response = AsyncioRequestsTransportResponse( + req_request, + req_response, + block_size, + ) + + async def mock_run(self, *args, **kwargs): + return PipelineResponse( + None, + requests.Response(), + None, + ) + + transport = AioHttpTransport() + pipeline = AsyncPipeline(transport) + pipeline.run = mock_run + downloader = response.stream_download(pipeline) + with pytest.raises(requests.exceptions.ConnectionError): + while True: + await downloader.__anext__() diff --git a/sdk/core/azure-core/tests/test_range_header.py b/sdk/core/azure-core/tests/test_range_header.py new file mode 100644 index 000000000000..c156ad19607f --- /dev/null +++ b/sdk/core/azure-core/tests/test_range_header.py @@ -0,0 +1,60 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# -------------------------------------------------------------------------- + +from azure.core.pipeline.transport._requests_basic import parse_range_header, make_range_header + +def test_basic_range_header(): + range_header = "bytes=0-500" + range = parse_range_header(range_header) + assert range == (0,500) + +def test_no_start_range_header(): + range_header = "bytes=-500" + range = parse_range_header(range_header) + assert range == (-1,500) + +def test_no_end_range_header(): + range_header = "bytes=0-" + range = parse_range_header(range_header) + assert range == (0,-1) + +def test_make_basic_range_header(): + range_header = "bytes=0-500" + range = parse_range_header(range_header) + header = make_range_header(range, 100) + assert header == "bytes=100-500" + +def test_make_no_start_range_header(): + range_header = "bytes=-500" + range = parse_range_header(range_header) + header = make_range_header(range, 100) + assert header == "bytes=-400" + +def test_make_no_end_range_header(): + range_header = "bytes=0-" + range = parse_range_header(range_header) + header = make_range_header(range, 100) + assert header == "bytes=100-" diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index 4f40bba885eb..65aaeb19f94d 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -7,6 +7,8 @@ HttpRequest, HttpResponse, HttpTransport, + RequestsTransport, + RequestsTransportResponse, ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator @@ -18,8 +20,8 @@ def test_connection_error_response(): class MockTransport(HttpTransport): - def __init__(self): - self._count = 0 + def __init__(self, error=True): + self._error = error def __exit__(self, exc_type, exc_val, exc_tb): pass @@ -32,19 +34,25 @@ def send(self, request, **kwargs): request = HttpRequest('GET', 'http://127.0.0.1/') response = HttpResponse(request, None) response.status_code = 200 + response.internal_response = MockInternalResponse(error=False) return response def next(self): self.__next__() def __next__(self): - if self._count == 0: - self._count += 1 + if self._error: raise requests.exceptions.ConnectionError + return None class MockInternalResponse(): + def __init__(self, error=True): + self._error = error + self.status_code = 200 + self.headers = {"etag":"etag"} + def iter_content(self, block_size): - return MockTransport() + return MockTransport(error=self._error) def close(self): pass @@ -52,7 +60,7 @@ def close(self): http_request = HttpRequest('GET', 'http://127.0.0.1/') pipeline = Pipeline(MockTransport()) http_response = HttpResponse(http_request, None) - http_response.internal_response = MockInternalResponse() + http_response.internal_response = MockInternalResponse(error=True) stream = StreamDownloadGenerator(pipeline, http_response) with mock.patch('time.sleep', return_value=None): with pytest.raises(StopIteration): @@ -85,6 +93,9 @@ def __next__(self): raise requests.exceptions.ConnectionError class MockInternalResponse(): + def __init__(self): + self.headers = {} + def iter_content(self, block_size): return MockTransport() @@ -98,4 +109,49 @@ def close(self): stream = StreamDownloadGenerator(pipeline, http_response) with mock.patch('time.sleep', return_value=None): with pytest.raises(requests.exceptions.ConnectionError): - stream.__next__() \ No newline at end of file + stream.__next__() + +def test_response_streaming_error_behavior(): + # Test to reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723 + block_size = 103 + total_response_size = 500 + req_response = requests.Response() + req_request = requests.Request() + + class FakeStreamWithConnectionError: + # fake object for urllib3.response.HTTPResponse + + def stream(self, chunk_size, decode_content=False): + assert chunk_size == block_size + left = total_response_size + while left > 0: + if left <= block_size: + raise requests.exceptions.ConnectionError() + data = b"X" * min(chunk_size, left) + left -= len(data) + yield data + + def close(self): + pass + + req_response.raw = FakeStreamWithConnectionError() + + response = RequestsTransportResponse( + req_request, + req_response, + block_size, + ) + + def mock_run(self, *args, **kwargs): + return PipelineResponse( + None, + requests.Response(), + None, + ) + + transport = RequestsTransport() + pipeline = Pipeline(transport) + pipeline.run = mock_run + downloader = response.stream_download(pipeline) + with pytest.raises(requests.exceptions.ConnectionError): + full_response = b"".join(downloader)