Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,32 @@ def __init__(
assert scope["type"] == "http"
self._receive = receive
self._send = send
self._stream_consumed = False
self._is_disconnected = False

# Not particularly graceful, but we store state around reading the request
# body in the ASGI scope, under the following...
#
# ['extensions']['starlette']['body']
# ['extensions']['starlette']['stream_consumed']
#
# This allows usages such as ASGI middleware to call receive and
# access the request body, and have that state persisted.
#
# Bit of an abuse of ASGI to take this approach. An alternate take would be
# that if you're going to use ASGI middleware it might be better to just
# accept the constraint that you *don't* get access to the request body in
# that context.
def _get_request_state(self, name: str, default: typing.Any = None) -> typing.Any:
return self.scope.get("extensions", {}).get("starlette", {}).get(name, default)

def _set_request_state(self, name: str, value: typing.Any) -> None:
if "extensions" not in self.scope:
self.scope["extensions"] = {"starlette": {name: value}}
elif "starlette" not in self.scope["extensions"]:
self.scope["extensions"]["starlette"] = {name: value}
else:
self.scope["extensions"]["starlette"][name] = value

@property
def method(self) -> str:
return self.scope["method"]
Expand All @@ -206,15 +229,17 @@ def receive(self) -> Receive:
return self._receive

async def stream(self) -> typing.AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
body = self._get_request_state("body")
if body is not None:
yield body
yield b""
return

if self._stream_consumed:
stream_consumed = self._get_request_state("stream_consumed", default=False)
if stream_consumed:
raise RuntimeError("Stream consumed")

self._stream_consumed = True
self._set_request_state("stream_consumed", True)
while True:
message = await self._receive()
if message["type"] == "http.request":
Expand All @@ -229,12 +254,14 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
yield b""

async def body(self) -> bytes:
if not hasattr(self, "_body"):
body = self._get_request_state("body")
if body is None:
chunks: "typing.List[bytes]" = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
return self._body
body = b"".join(chunks)
self._set_request_state("body", body)
return body

async def json(self) -> typing.Any:
if not hasattr(self, "_json"):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,65 @@ async def app(scope, receive, send):
client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "Send channel not available"}


def test_request_body_then_request_body(test_client_factory):
    """Reading the body through one ``Request`` must not stop a second
    ``Request`` built on the same scope from returning the content again."""

    async def endpoint(scope, receive, send):
        first_body = await Request(scope, receive).body()
        second_body = await Request(scope, receive).body()
        payload = {"body": first_body.decode(), "body2": second_body.decode()}
        await JSONResponse(payload)(scope, receive, send)

    reply = test_client_factory(endpoint).post("/", data="abc")
    assert reply.json() == {"body": "abc", "body2": "abc"}


def test_request_stream_then_request_body(test_client_factory):
    """After one ``Request`` has streamed the body, a second ``Request`` on
    the same scope must raise ``RuntimeError`` when asked for the content."""

    async def endpoint(scope, receive, send):
        pieces = [piece async for piece in Request(scope, receive).stream()]
        streamed = b"".join(pieces)

        second = Request(scope, receive)
        try:
            content = await second.body()
        except RuntimeError:
            # Marker value so the outcome is visible in the JSON response.
            content = b"<stream consumed>"

        payload = {"body": content.decode(), "stream": streamed.decode()}
        await JSONResponse(payload)(scope, receive, send)

    reply = test_client_factory(endpoint).post("/", data="abc")
    assert reply.json() == {"body": "<stream consumed>", "stream": "abc"}


def test_request_body_then_request_stream(test_client_factory):
    """After one ``Request`` has read the whole body, a second ``Request`` on
    the same scope must still be able to stream the content."""

    async def endpoint(scope, receive, send):
        content = await Request(scope, receive).body()

        second = Request(scope, receive)
        streamed = b"".join([piece async for piece in second.stream()])

        payload = {"body": content.decode(), "stream": streamed.decode()}
        await JSONResponse(payload)(scope, receive, send)

    reply = test_client_factory(endpoint).post("/", data="abc")
    assert reply.json() == {"body": "abc", "stream": "abc"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert response.json() == {"body": "abc", "stream": "abc"}
assert response.json() == {"body": "abc", "stream": "abc"}
def test_request_body_then_replace_body(test_client_factory):
    """An inner middleware that rewrites the request body (here by
    zlib-decompressing it) must determine what the app reads, even though an
    outer middleware has already read the raw body through a ``Request``."""
    import zlib

    data = "Hello, world!"
    compressed_payload = zlib.compress(data.encode())

    def ungzip_middleware(app):
        async def wrapped_app(scope, receive, send):
            # Default decompressobj() expects the zlib container that
            # zlib.compress() produced above.
            decompressor = zlib.decompressobj()
            pending = bytearray()

            async def wrapped_rcv():
                more_body = True
                while more_body:
                    message = await receive()
                    more_body = message.get("more_body", False)
                    pending.extend(message["body"])
                    # Hold back small chunks until more than 16 bytes are
                    # buffered, but always flush on the final message.
                    if len(pending) > 16 or not more_body:
                        message["body"] = decompressor.decompress(pending)
                        yield message
                        pending.clear()

            # Each receive() call made by the app advances the generator
            # by one yielded message.
            await app(scope, wrapped_rcv().__anext__, send)

        return wrapped_app

    def log_request_body_middleware(app):
        async def wrapped_app(scope, receive, send):
            # Read (and discard) the body before handing over to the app.
            await Request(scope, receive).body()
            await app(scope, receive, send)

        return wrapped_app

    async def endpoint(scope, receive, send):
        content = await Request(scope, receive).body()
        await JSONResponse({"body": content.decode()})(scope, receive, send)

    stacked = log_request_body_middleware(ungzip_middleware(endpoint))
    reply = test_client_factory(stacked).post("/", data=compressed_payload)
    assert reply.json() == {"body": data}

@gnat this is what I'm referring to

ungzip_middleware is a perfectly valid ASGI middleware. It could even be inside a mounted ASGI app or something like that. If we implement this PR, we completely break this sort of thing, and it breaks in a very non-intuitive way. We're essentially dealing with cache invalidation.