diff --git a/sdk/core/azure-core/HISTORY.md b/sdk/core/azure-core/HISTORY.md index aefe1cdfce9f..bf36d12524d5 100644 --- a/sdk/core/azure-core/HISTORY.md +++ b/sdk/core/azure-core/HISTORY.md @@ -7,6 +7,7 @@ ### Bug fixes +- Fix AsyncioRequestsTransport if input stream is an async generator #7743 - Fix form-data with aiohttp transport #7749 ## 2019-10-07 Version 1.0.0b4 diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/requests_asyncio.py index 2d237d69774f..6f0a9c2a833c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/requests_asyncio.py @@ -102,6 +102,16 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: loop = kwargs.get("loop", _get_running_loop()) response = None error = None # type: Optional[Union[ServiceRequestError, ServiceResponseError]] + if hasattr(request.data, '__aiter__'): + # Need to consume that async generator, since requests can't do anything with it + # That's not ideal, but a list is our only choice. Memory not optimal here, + # but providing an async generator to a requests based transport is not optimal too + new_data = [] + async for part in request.data: + new_data.append(part) + data_to_send = iter(new_data) + else: + data_to_send = request.data try: response = await loop.run_in_executor( None, @@ -110,7 +120,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse: request.method, request.url, headers=request.headers, - data=request.data, + data=data_to_send, files=request.files, verify=kwargs.pop('connection_verify', self.connection_config.verify), timeout=kwargs.pop('connection_timeout', self.connection_config.timeout), diff --git a/sdk/core/azure-core/tests/azure_core_asynctests/test_request_asyncio.py b/sdk/core/azure-core/tests/azure_core_asynctests/test_request_asyncio.py new file mode 100644 index 000000000000..666bdd39276e --- /dev/null +++ b/sdk/core/azure-core/tests/azure_core_asynctests/test_request_asyncio.py @@ -0,0 +1,40 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import json + +from azure.core.pipeline.transport import AsyncioRequestsTransport, HttpRequest + +import pytest + + +@pytest.mark.asyncio +async def test_async_gen_data(): + transport = AsyncioRequestsTransport() + + class AsyncGen: + def __init__(self): + self._range = iter([b"azerty"]) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._range) + except StopIteration: + raise StopAsyncIteration + + req = HttpRequest('GET', 'http://httpbin.org/post', data=AsyncGen()) + + await transport.send(req) + +@pytest.mark.asyncio +async def test_send_data(): + transport = AsyncioRequestsTransport() + req = HttpRequest('PUT', 'http://httpbin.org/anything', data=b"azerty") + response = await transport.send(req) + + assert json.loads(response.text())['data'] == "azerty" \ No newline at end of file