Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/responses.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class App:

### StreamingResponse

Takes an async generator and streams the response body.
Takes an async generator or a normal generator/iterator and streams the response body.

```python
from starlette.responses import StreamingResponse
Expand All @@ -180,6 +180,8 @@ class App:
await response(receive, send)
```

Have in mind that <a href="https://docs.python.org/3/glossary.html#term-file-like-object" target="_blank">file-like</a> objects (like those created by `open()`) are normal iterators. So, you can return them directly in a `StreamingResponse`.

### FileResponse

Asynchronously streams a file as the response.
Expand Down
22 changes: 22 additions & 0 deletions starlette/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import typing
from typing import Any, AsyncGenerator, Iterator

try:
import contextvars # Python 3.7+ only.
Expand All @@ -22,3 +23,24 @@ async def run_in_threadpool(
# loop.run_in_executor doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await loop.run_in_executor(None, func, *args)


class _StopSyncIteration(Exception):
pass


def _interceptable_next(iterator: Iterator) -> Any:
try:
result = next(iterator)
return result
except StopIteration:
raise _StopSyncIteration


async def iterator_to_async(iterator: Iterator) -> AsyncGenerator:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about iterate_in_threadpool to mirror run_in_threadpool.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea. I struggled a lot to name it and wasn't quite convinced either.

while True:
try:
result = await run_in_threadpool(_interceptable_next, iterator)
yield result
except _StopSyncIteration:
break
7 changes: 6 additions & 1 deletion starlette/responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import http.cookies
import inspect
import json
import os
import stat
Expand All @@ -9,6 +10,7 @@
from urllib.parse import quote_plus

from starlette.background import BackgroundTask
from starlette.concurrency import iterator_to_async
from starlette.datastructures import URL, MutableHeaders
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -175,7 +177,10 @@ def __init__(
media_type: str = None,
background: BackgroundTask = None,
) -> None:
self.body_iterator = content
if inspect.isasyncgen(content):
self.body_iterator = content
else:
self.body_iterator = iterator_to_async(content)
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
Expand Down
18 changes: 18 additions & 0 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from starlette import status
from starlette.background import BackgroundTask
from starlette.concurrency import iterator_to_async
from starlette.requests import Request
from starlette.responses import (
FileResponse,
Expand Down Expand Up @@ -90,6 +91,23 @@ async def numbers_for_cleanup(start=1, stop=5):
assert filled_by_bg_task == "6, 7, 8, 9"


def test_sync_streaming_response():
async def app(scope, receive, send):
def numbers(minimum, maximum):
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "

generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"


def test_response_headers():
async def app(scope, receive, send):
headers = {"x-header-1": "123", "x-header-2": "456"}
Expand Down