Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

## 1.12.0 (Unreleased)

### Features
### Bug Fixes

- Raise exception rather than swallowing it if there is something wrong in retry stream downloading #16723
- Added `azure.core.messaging.CloudEvent` model that follows the cloud event spec.
- Added `azure.core.serialization.NULL` sentinel value

Expand Down
84 changes: 57 additions & 27 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from azure.core.exceptions import ServiceRequestError, ServiceResponseError
from azure.core.pipeline import Pipeline

from ._base import HttpRequest
from ._base import HttpRequest, parse_range_header, make_range_header
from ._base_async import (
AsyncHttpTransport,
AsyncHttpResponse,
Expand Down Expand Up @@ -195,7 +195,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
return response


class AioHttpStreamDownloadGenerator(AsyncIterator):
class AioHttpStreamDownloadGenerator(AsyncIterator): # pylint: disable=too-many-instance-attributes
"""Streams the response body data.

:param pipeline: The pipeline object
Expand All @@ -208,48 +208,78 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
self.request = response.request
self.response = response
self.block_size = response.block_size
self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
self.downloaded = 0
headers = response.internal_response.headers
self.content_length = int(headers.get('Content-Length', 0))
transfer_header = headers.get('Transfer-Encoding', '')
self._compressed = 'compress' in transfer_header or 'deflate' in transfer_header or 'gzip' in transfer_header
if "x-ms-range" in headers:
self.range_header = "x-ms-range" # type: Optional[str]
self.range = parse_range_header(headers["x-ms-range"])
elif "Range" in headers:
self.range_header = "Range"
self.range = parse_range_header(headers["Range"])
else:
self.range_header = None
self.etag = headers.get('etag')

def __len__(self):
return self.content_length

async def __anext__(self):
async def __anext__(self): # pylint:disable=too-many-statements
retry_active = True
retry_total = 3
retry_interval = 1 # 1 second
while retry_active:
try:
chunk = await self.response.internal_response.content.read(self.block_size)
if not chunk:
raise _ResponseStopIteration()
self.downloaded += self.block_size
return chunk
except _ResponseStopIteration:
try:
chunk = await self.response.internal_response.content.read(self.block_size)
if not chunk:
self.response.internal_response.close()
raise StopAsyncIteration()
except (ChunkedEncodingError, ConnectionError):
raise _ResponseStopIteration()
self.downloaded += self.block_size
return chunk
except _ResponseStopIteration:
raise StopAsyncIteration()
except (ChunkedEncodingError, ConnectionError) as ex:
if self._compressed:
raise ex
while retry_active:
retry_total -= 1
if retry_total <= 0:
retry_active = False
_LOGGER.warning("Unable to stream download: %s", ex)
raise ex
if not self.etag:
_LOGGER.warning("Unable to stream download: %s", ex)
raise ex
await asyncio.sleep(retry_interval)
headers = self.request.headers.copy()
if not self.range_header:
range_header = {'range': 'bytes=' + str(self.downloaded) + '-'}
else:
await asyncio.sleep(retry_interval)
headers = {'range': 'bytes=' + str(self.downloaded) + '-'}
range_header = {self.range_header: make_range_header(self.range, self.downloaded)}
range_header.update({'If-Match': self.etag})
headers.update(range_header)
try:
resp = await self.pipeline.run(self.request, stream=True, headers=headers)
if resp.http_response.status_code == 416:
raise
if not resp.http_response:
continue
if resp.http_response.status_code == 412:
continue
self.response = resp.http_response
chunk = await self.response.internal_response.content.read(self.block_size)
if not chunk:
self.response.internal_response.close()
raise StopAsyncIteration()
self.downloaded += len(chunk)
self.downloaded += self.block_size
return chunk
continue
except StreamConsumedError:
raise
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
self.response.internal_response.close()
raise
except StopAsyncIteration:
raise StopAsyncIteration()
except Exception: # pylint: disable=broad-except
continue
except StreamConsumedError:
raise
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
raise

class AioHttpTransportResponse(AsyncHttpResponse):
"""Methods for accessing response body data.
Expand Down
22 changes: 22 additions & 0 deletions sdk/core/azure-core/azure/core/pipeline/transport/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,25 @@ def options(self, url, params=None, headers=None, **kwargs):
"OPTIONS", url, params, headers, content, form_content, None
)
return request


def parse_range_header(header_value):
range_value = header_value.strip()
if not range_value.startswith("bytes="):
raise ValueError("Invalid header")
range_str = range_value[6:]
ret = range_str.split("-")
if len(ret) < 2:
raise ValueError("Invalid header")
start = int(ret[0]) if ret[0] else -1
end = int(ret[1]) if ret[1] else -1
return (start, end)

def make_range_header(original_range, downloaded_size=0):
if original_range[0] == -1:
end = original_range[1] - downloaded_size
return "bytes=-" + str(end)
start = original_range[0] + downloaded_size
if original_range[1] == -1:
return "bytes=" + str(start) + "-"
return "bytes=" + str(start) + "-" + str(original_range[1])
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
ServiceResponseError
)
from azure.core.pipeline import Pipeline
from ._base import HttpRequest
from ._base import HttpRequest, parse_range_header, make_range_header
from ._base_async import (
AsyncHttpResponse,
_ResponseStopIteration,
Expand Down Expand Up @@ -133,7 +133,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse:
return AsyncioRequestsTransportResponse(request, response, self.connection_config.data_block_size)


class AsyncioStreamDownloadGenerator(AsyncIterator):
class AsyncioStreamDownloadGenerator(AsyncIterator): # pylint: disable=too-many-instance-attributes
"""Streams the response body data.

:param pipeline: The pipeline object
Expand All @@ -147,59 +147,86 @@ def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
self.response = response
self.block_size = response.block_size
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
self.content_length = int(response.headers.get('Content-Length', 0))
self.downloaded = 0
headers = response.internal_response.headers
self.content_length = int(headers.get('Content-Length', 0))
transfer_header = headers.get('Transfer-Encoding', '')
self._compressed = 'compress' in transfer_header or 'deflate' in transfer_header or 'gzip' in transfer_header
if "x-ms-range" in headers:
self.range_header = "x-ms-range" # type: Optional[str]
self.range = parse_range_header(headers["x-ms-range"])
elif "Range" in headers:
self.range_header = "Range"
self.range = parse_range_header(headers["Range"])
else:
self.range_header = None
self.etag = headers.get('etag')

def __len__(self):
return self.content_length

async def __anext__(self):
async def __anext__(self): # pylint:disable=too-many-statements
loop = _get_running_loop()
retry_active = True
retry_total = 3
retry_interval = 1 # 1 second
while retry_active:
try:
chunk = await loop.run_in_executor(
None,
_iterate_response_content,
self.iter_content_func,
)
if not chunk:
raise _ResponseStopIteration()
self.downloaded += self.block_size
return chunk
except _ResponseStopIteration:
self.response.internal_response.close()
raise StopAsyncIteration()
except (requests.exceptions.ChunkedEncodingError,
requests.exceptions.ConnectionError):
try:
chunk = await loop.run_in_executor(
None,
_iterate_response_content,
self.iter_content_func,
)
if not chunk:
raise _ResponseStopIteration()
self.downloaded += self.block_size
return chunk
except _ResponseStopIteration:
raise StopAsyncIteration()
except (requests.exceptions.ChunkedEncodingError,
requests.exceptions.ConnectionError) as ex:
if self._compressed:
raise ex
while retry_active:
retry_total -= 1
if retry_total <= 0:
retry_active = False
_LOGGER.warning("Unable to stream download: %s", ex)
raise ex
if not self.etag:
_LOGGER.warning("Unable to stream download: %s", ex)
raise ex
await asyncio.sleep(retry_interval)
headers = self.request.headers.copy()
if not self.range_header:
range_header = {'range': 'bytes=' + str(self.downloaded) + '-'}
else:
await asyncio.sleep(retry_interval)
headers = {'range': 'bytes=' + str(self.downloaded) + '-'}
resp = self.pipeline.run(self.request, stream=True, headers=headers)
if resp.status_code == 416:
raise
range_header = {self.range_header: make_range_header(self.range, self.downloaded)}
range_header.update({'If-Match': self.etag})
headers.update(range_header)
try:
resp = await self.pipeline.run(self.request, stream=True, headers=headers)
if not resp.http_response:
continue
if resp.http_response.status_code == 412:
continue
self.response = resp.http_response
chunk = await loop.run_in_executor(
None,
_iterate_response_content,
self.iter_content_func,
)
if not chunk:
raise StopIteration()
self.downloaded += len(chunk)
raise StopAsyncIteration()
self.downloaded += self.block_size
return chunk
continue
except requests.exceptions.StreamConsumedError:
raise
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
self.response.internal_response.close()
raise

except StopAsyncIteration:
raise StopAsyncIteration()
except Exception: # pylint: disable=broad-except
continue
except requests.exceptions.StreamConsumedError:
raise
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
raise

class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
"""Asynchronous streaming of data from the response.
Expand Down
Loading