diff --git a/docs/responses.md b/docs/responses.md index 595c9101c..9aaf24e1f 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -154,7 +154,7 @@ class App: ### StreamingResponse -Takes an async generator and streams the response body. +Takes an async generator or a normal generator/iterator and streams the response body. ```python from starlette.responses import StreamingResponse @@ -180,6 +180,8 @@ class App: await response(receive, send) ``` +Have in mind that file-like objects (like those created by `open()`) are normal iterators. So, you can return them directly in a `StreamingResponse`. + ### FileResponse Asynchronously streams a file as the response. diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 35b589956..dd9860aaf 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,6 +1,7 @@ import asyncio import functools import typing +from typing import Any, AsyncGenerator, Iterator try: import contextvars # Python 3.7+ only. @@ -22,3 +23,24 @@ async def run_in_threadpool( # loop.run_in_executor doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) return await loop.run_in_executor(None, func, *args) + + +class _StopSyncIteration(Exception): + pass + + +def _interceptable_next(iterator: Iterator) -> Any: + try: + result = next(iterator) + return result + except StopIteration: + raise _StopSyncIteration + + +async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: + while True: + try: + result = await run_in_threadpool(_interceptable_next, iterator) + yield result + except _StopSyncIteration: + break diff --git a/starlette/responses.py b/starlette/responses.py index 1fee57cc0..48afc47ad 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -1,5 +1,6 @@ import hashlib import http.cookies +import inspect import json import os import stat @@ -9,6 +10,7 @@ from urllib.parse import quote_plus from starlette.background import BackgroundTask +from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send @@ -175,7 +177,10 @@ def __init__( media_type: str = None, background: BackgroundTask = None, ) -> None: - self.body_iterator = content + if inspect.isasyncgen(content): + self.body_iterator = content + else: + self.body_iterator = iterate_in_threadpool(content) self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type self.background = background diff --git a/tests/test_responses.py b/tests/test_responses.py index 300975afe..3d5de413f 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -90,6 +90,23 @@ async def numbers_for_cleanup(start=1, stop=5): assert filled_by_bg_task == "6, 7, 8, 9" +def test_sync_streaming_response(): + async def app(scope, receive, send): + def numbers(minimum, maximum): + for i in range(minimum, maximum + 1): + yield str(i) + if i != maximum: + yield ", " + + generator = numbers(1, 5) + response = StreamingResponse(generator, media_type="text/plain") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.text == "1, 2, 3, 4, 5" + + def test_response_headers(): async def app(scope, receive, send): headers = {"x-header-1": "123", "x-header-2": "456"}